
Commit 8663822

Fix kernel cache miss and add RDNA configs (#246)
* Fix kernel cache miss and add RDNA configs
  - added Navi configurations (Related PR: ROCm/triton#640)
  - resolved cache miss issue during flash attention calls by fixing max_seqlen_q/k to 0

* Remove Navi autotune configs for triton FP8 support
1 parent 2b17421 commit 8663822
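
For orientation, the change boils down to two things: the autotune configuration list is now selected per GPU family (CDNA vs. RDNA/Navi), and on Navi the max sequence-length arguments are pinned to 0 before the kernel launch so the autotuner does not keep missing its cache. The sketch below condenses that control flow; it is an illustration only, not the module's code: is_navi() is stubbed, the configs are plain dicts instead of triton.Config objects, and launch() is a hypothetical stand-in for the real attn_fwd launch shown in the diff below.

# Condensed, illustrative sketch of the control flow this commit introduces.
# Assumptions: is_navi() is stubbed, configs are plain dicts, and launch()
# stands in for the real attn_fwd kernel launch in triton_flash_attention.py.

def is_navi() -> bool:
    # Stub for vllm.utils.is_navi (detects RDNA/Navi GPUs).
    return False

def get_cdna_autotune_configs():
    # Larger tiles for CDNA (MI-series) GPUs, plus the keys the autotuner caches on.
    configs = [{"BLOCK_M": 256, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}]
    return configs, ["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL", "USE_FP8"]

def get_rdna_autotune_configs():
    # Smaller tiles for RDNA (Navi) GPUs.
    configs = [{"BLOCK_M": 32, "BLOCK_N": 32, "waves_per_eu": 4, "PRE_LOAD_V": False}]
    return configs, ["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL", "USE_FP8"]

def get_autotune_configs():
    # Pick the config set once, based on the GPU family.
    return get_rdna_autotune_configs() if is_navi() else get_cdna_autotune_configs()

def launch(max_seqlens_q: int, max_seqlens_k: int):
    # On Navi, varying max sequence lengths were triggering repeated kernel
    # cache misses on each flash-attention call, so they are pinned to 0
    # before the launch (mirroring the last hunk of the diff).
    if is_navi():
        max_seqlens_q = 0
        max_seqlens_k = 0
    return max_seqlens_q, max_seqlens_k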

File tree

1 file changed: +135 / -52 lines changed


vllm/attention/ops/triton_flash_attention.py

Lines changed: 135 additions & 52 deletions
@@ -24,6 +24,8 @@
 import triton
 import triton.language as tl
 
+from vllm.utils import is_navi
+
 torch_dtype: tl.constexpr = torch.float16
 
 
@@ -217,88 +219,80 @@ def _attn_fwd_inner(
     return acc, l_i, m_i
 
 
-@triton.autotune(
-    configs=[
+def get_cdna_autotune_configs():
+    return [
         triton.Config(
             {
-                "BLOCK_M": 256,
-                "BLOCK_N": 64,
-                "waves_per_eu": 2,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 256,
+                'BLOCK_N': 64,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=8,
-        ),
+            num_warps=8),
         triton.Config(
             {
-                "BLOCK_M": 128,
-                "BLOCK_N": 128,
-                "waves_per_eu": 2,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 128,
+                'BLOCK_N': 128,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=4,
-        ),
+            num_warps=4),
         triton.Config(
             {
-                "BLOCK_M": 256,
-                "BLOCK_N": 128,
-                "waves_per_eu": 2,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 256,
+                'BLOCK_N': 128,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
            },
             num_stages=1,
-            num_warps=8,
-        ),
+            num_warps=8),
         triton.Config(
             {
-                "BLOCK_M": 128,
-                "BLOCK_N": 64,
-                "waves_per_eu": 1,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 128,
+                'BLOCK_N': 64,
+                'waves_per_eu': 1,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=4,
-        ),
+            num_warps=4),
         triton.Config(
             {
-                "BLOCK_M": 128,
-                "BLOCK_N": 64,
-                "waves_per_eu": 3,
-                "PRE_LOAD_V": True,
+                'BLOCK_M': 128,
+                'BLOCK_N': 64,
+                'waves_per_eu': 3,
+                'PRE_LOAD_V': True
             },
             num_stages=1,
-            num_warps=4,
-        ),
+            num_warps=4),
         triton.Config(
             {
-                "BLOCK_M": 128,
-                "BLOCK_N": 64,
-                "waves_per_eu": 3,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 128,
+                'BLOCK_N': 64,
+                'waves_per_eu': 3,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=4,
-        ),
+            num_warps=4),
         triton.Config(
             {
-                "BLOCK_M": 64,
-                "BLOCK_N": 64,
-                "waves_per_eu": 4,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 64,
+                'BLOCK_N': 64,
+                'waves_per_eu': 4,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=8,
-        ),
+            num_warps=8),
         triton.Config(
             {
-                "BLOCK_M": 32,
-                "BLOCK_N": 32,
-                "waves_per_eu": 4,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 32,
+                'BLOCK_N': 32,
+                'waves_per_eu': 4,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=8,
-        ),
+            num_warps=8),
         # TODO: This config fails with head_size not pow2 with data mismatches.
         # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
         #                'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
@@ -314,8 +308,93 @@ def _attn_fwd_inner(
         #     num_stages=1,
         #     num_warps=4,
         # ),
-    ],
-    key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8'],
+    ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8']
+
+
+def get_rdna_autotune_configs():
+    return [
+        triton.Config(
+            {
+                'BLOCK_M': 32,
+                'BLOCK_N': 32,
+                'waves_per_eu': 4,
+                'PRE_LOAD_V': False
+            },
+            num_stages=1,
+            num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_M': 32,
+                'BLOCK_N': 32,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
+            },
+            num_stages=1,
+            num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_M': 32,
+                'BLOCK_N': 16,
+                'waves_per_eu': 4,
+                'PRE_LOAD_V': False
+            },
+            num_stages=1,
+            num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_M': 32,
+                'BLOCK_N': 16,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
+            },
+            num_stages=1,
+            num_warps=2),
+        # Fails in AccelerateAMDMatmul (Triton) assert when using FP8:
+        # triton.Config(
+        #     {
+        #         'BLOCK_M': 16,
+        #         'BLOCK_N': 16,
+        #         'waves_per_eu': 4,
+        #         'PRE_LOAD_V': False
+        #     },
+        #     num_stages=1,
+        #     num_warps=2),
+        # triton.Config(
+        #     {
+        #         'BLOCK_M': 16,
+        #         'BLOCK_N': 16,
+        #         'waves_per_eu': 2,
+        #         'PRE_LOAD_V': False
+        #     },
+        #     num_stages=1,
+        #     num_warps=2),
+        # # Fall-back config.
+        # triton.Config(
+        #     {
+        #         'BLOCK_M': 16,
+        #         'BLOCK_N': 16,
+        #         'waves_per_eu': 1,
+        #         'PRE_LOAD_V': False
+        #     },
+        #     num_stages=1,
+        #     num_warps=2),
+    ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8']
+
+
+def get_autotune_configs():
+    if is_navi():
+        return get_rdna_autotune_configs()
+    else:
+        return get_cdna_autotune_configs()
+
+
+autotune_configs, autotune_keys = get_autotune_configs()
+
+
+@triton.autotune(
+    configs=autotune_configs,
+    key=autotune_keys,
+    use_cuda_graph=True,
 )
 @triton.jit
 def attn_fwd(
@@ -833,6 +912,10 @@ def check_and_convert(t, scale):
         p_descale = 1.0 / p_scale
         o_descale = 1.0 / o_scale
 
+        if is_navi():
+            max_seqlens_q = 0
+            max_seqlens_k = 0
+
         attn_fwd[grid](
             q,
             k,

0 commit comments
