Commit ab33e0f
Fix kernel cache miss and add RDNA configs

- added Navi configurations (Related PR: ROCm/triton#640)
- resolved cache miss issue during flash attention calls by fixing max_seqlen_q/k to 0
1 parent da8f61a commit ab33e0f

File tree: 1 file changed, +134 -52 lines

vllm/attention/ops/triton_flash_attention.py (134 additions, 52 deletions)
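For orientation before the raw diff: the flat configs=[...] list that previously sat inside @triton.autotune is factored into get_cdna_autotune_configs() and get_rdna_autotune_configs(), each returning a (configs, keys) pair, and the decorator now consumes whichever pair get_autotune_configs() selects via is_navi(). The condensed sketch below mirrors that structure (config lists trimmed to one entry each for brevity; it is an illustration of the pattern, not the file contents):

# Condensed sketch of the new structure; the full config lists are in the diff below.
import triton
from vllm.utils import is_navi


def get_cdna_autotune_configs():
    # CDNA (Instinct) tile shapes; the autotune key list also includes USE_FP8.
    return [
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2,
                       'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
    ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8']


def get_rdna_autotune_configs():
    # RDNA (Navi) tile shapes are smaller and use fewer warps.
    return [
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4,
                       'PRE_LOAD_V': False}, num_stages=1, num_warps=2),
    ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL']


def get_autotune_configs():
    # Pick the config/key pair once, at import time.
    if is_navi():
        return get_rdna_autotune_configs()
    return get_cdna_autotune_configs()


autotune_configs, autotune_keys = get_autotune_configs()
# The kernel is then decorated with:
# @triton.autotune(configs=autotune_configs, key=autotune_keys,
#                  use_cuda_graph=True)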
@@ -24,6 +24,8 @@
 import triton
 import triton.language as tl
 
+from vllm.utils import is_navi
+
 torch_dtype: tl.constexpr = torch.float16
 
 
@@ -217,88 +219,80 @@ def _attn_fwd_inner(
     return acc, l_i, m_i
 
 
-@triton.autotune(
-    configs=[
+def get_cdna_autotune_configs():
+    return [
         triton.Config(
             {
-                "BLOCK_M": 256,
-                "BLOCK_N": 64,
-                "waves_per_eu": 2,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 256,
+                'BLOCK_N': 64,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=8,
-        ),
+            num_warps=8),
         triton.Config(
             {
-                "BLOCK_M": 128,
-                "BLOCK_N": 128,
-                "waves_per_eu": 2,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 128,
+                'BLOCK_N': 128,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=4,
-        ),
+            num_warps=4),
         triton.Config(
             {
-                "BLOCK_M": 256,
-                "BLOCK_N": 128,
-                "waves_per_eu": 2,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 256,
+                'BLOCK_N': 128,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=8,
-        ),
+            num_warps=8),
         triton.Config(
             {
-                "BLOCK_M": 128,
-                "BLOCK_N": 64,
-                "waves_per_eu": 1,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 128,
+                'BLOCK_N': 64,
+                'waves_per_eu': 1,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=4,
-        ),
+            num_warps=4),
         triton.Config(
             {
-                "BLOCK_M": 128,
-                "BLOCK_N": 64,
-                "waves_per_eu": 3,
-                "PRE_LOAD_V": True,
+                'BLOCK_M': 128,
+                'BLOCK_N': 64,
+                'waves_per_eu': 3,
+                'PRE_LOAD_V': True
             },
             num_stages=1,
-            num_warps=4,
-        ),
+            num_warps=4),
         triton.Config(
             {
-                "BLOCK_M": 128,
-                "BLOCK_N": 64,
-                "waves_per_eu": 3,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 128,
+                'BLOCK_N': 64,
+                'waves_per_eu': 3,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=4,
-        ),
+            num_warps=4),
         triton.Config(
             {
-                "BLOCK_M": 64,
-                "BLOCK_N": 64,
-                "waves_per_eu": 4,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 64,
+                'BLOCK_N': 64,
+                'waves_per_eu': 4,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=8,
-        ),
+            num_warps=8),
         triton.Config(
             {
-                "BLOCK_M": 32,
-                "BLOCK_N": 32,
-                "waves_per_eu": 4,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 32,
+                'BLOCK_N': 32,
+                'waves_per_eu': 4,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=8,
-        ),
+            num_warps=8),
         # TODO: This config fails with head_size not pow2 with data mismatches.
         # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
         #                'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
@@ -314,8 +308,92 @@ def _attn_fwd_inner(
         # num_stages=1,
         # num_warps=4,
         # ),
-    ],
-    key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8'],
+    ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8'],
+
+
+def get_rdna_autotune_configs():
+    return [
+        triton.Config(
+            {
+                'BLOCK_M': 32,
+                'BLOCK_N': 32,
+                'waves_per_eu': 4,
+                'PRE_LOAD_V': False
+            },
+            num_stages=1,
+            num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_M': 32,
+                'BLOCK_N': 32,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
+            },
+            num_stages=1,
+            num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_M': 32,
+                'BLOCK_N': 16,
+                'waves_per_eu': 4,
+                'PRE_LOAD_V': False
+            },
+            num_stages=1,
+            num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_M': 32,
+                'BLOCK_N': 16,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
+            },
+            num_stages=1,
+            num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_M': 16,
+                'BLOCK_N': 16,
+                'waves_per_eu': 4,
+                'PRE_LOAD_V': False
+            },
+            num_stages=1,
+            num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_M': 16,
+                'BLOCK_N': 16,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
+            },
+            num_stages=1,
+            num_warps=2),
+        # Fall-back config.
+        triton.Config(
+            {
+                'BLOCK_M': 16,
+                'BLOCK_N': 16,
+                'waves_per_eu': 1,
+                'PRE_LOAD_V': False
+            },
+            num_stages=1,
+            num_warps=2),
+    ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL']
+
+
+def get_autotune_configs():
+    if is_navi():
+        return get_rdna_autotune_configs()
+    else:
+        return get_cdna_autotune_configs()
+
+
+autotune_configs, autotune_keys = get_autotune_configs()
+
+
+@triton.autotune(
+    configs=autotune_configs,
+    key=autotune_keys,
+    use_cuda_graph=True,
 )
 @triton.jit
 def attn_fwd(
@@ -833,6 +911,10 @@ def check_and_convert(t, scale):
         p_descale = 1.0 / p_scale
         o_descale = 1.0 / o_scale
 
+        if is_navi():
+            max_seqlens_q = 0
+            max_seqlens_k = 0
+
         attn_fwd[grid](
             q,
             k,
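A note on the max_seqlens change in the last hunk: the commit message attributes the kernel cache misses to these values, which is consistent with Triton specializing compiled kernels on properties of its integer launch arguments, so a max sequence length that varies from call to call can key a fresh compilation each time. A minimal, hypothetical caller-side sketch of the same guard (only is_navi comes from the diff; the helper name and signature here are illustrative):

# Hypothetical illustration of the Navi cache-miss fix; the real change is the
# four added lines in the hunk above, inside the attention forward pass.
from vllm.utils import is_navi


def pin_seqlens_for_navi(max_seqlens_q: int, max_seqlens_k: int) -> tuple[int, int]:
    # Per-call variation in these integers can trigger Triton re-specialization
    # (seen as kernel cache misses); on Navi the commit pins both to 0 so the
    # launch signature stays stable across calls.
    if is_navi():
        return 0, 0
    return max_seqlens_q, max_seqlens_k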
