diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py
index 3d53cd4b5700f..a49df831b46ea 100644
--- a/vllm/attention/ops/triton_flash_attention.py
+++ b/vllm/attention/ops/triton_flash_attention.py
@@ -24,6 +24,8 @@
 import triton
 import triton.language as tl
 
+from vllm.utils import is_navi
+
 torch_dtype: tl.constexpr = torch.float16
@@ -217,88 +219,80 @@ def _attn_fwd_inner(
     return acc, l_i, m_i
 
 
-@triton.autotune(
-    configs=[
+def get_cdna_autotune_configs():
+    return [
         triton.Config(
             {
-                "BLOCK_M": 256,
-                "BLOCK_N": 64,
-                "waves_per_eu": 2,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 256,
+                'BLOCK_N': 64,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=8,
-        ),
+            num_warps=8),
         triton.Config(
             {
-                "BLOCK_M": 128,
-                "BLOCK_N": 128,
-                "waves_per_eu": 2,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 128,
+                'BLOCK_N': 128,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=4,
-        ),
+            num_warps=4),
         triton.Config(
             {
-                "BLOCK_M": 256,
-                "BLOCK_N": 128,
-                "waves_per_eu": 2,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 256,
+                'BLOCK_N': 128,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=8,
-        ),
+            num_warps=8),
         triton.Config(
             {
-                "BLOCK_M": 128,
-                "BLOCK_N": 64,
-                "waves_per_eu": 1,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 128,
+                'BLOCK_N': 64,
+                'waves_per_eu': 1,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=4,
-        ),
+            num_warps=4),
         triton.Config(
             {
-                "BLOCK_M": 128,
-                "BLOCK_N": 64,
-                "waves_per_eu": 3,
-                "PRE_LOAD_V": True,
+                'BLOCK_M': 128,
+                'BLOCK_N': 64,
+                'waves_per_eu': 3,
+                'PRE_LOAD_V': True
             },
             num_stages=1,
-            num_warps=4,
-        ),
+            num_warps=4),
         triton.Config(
             {
-                "BLOCK_M": 128,
-                "BLOCK_N": 64,
-                "waves_per_eu": 3,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 128,
+                'BLOCK_N': 64,
+                'waves_per_eu': 3,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=4,
-        ),
+            num_warps=4),
         triton.Config(
             {
-                "BLOCK_M": 64,
-                "BLOCK_N": 64,
-                "waves_per_eu": 4,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 64,
+                'BLOCK_N': 64,
+                'waves_per_eu': 4,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=8,
-        ),
+            num_warps=8),
         triton.Config(
             {
-                "BLOCK_M": 32,
-                "BLOCK_N": 32,
-                "waves_per_eu": 4,
-                "PRE_LOAD_V": False,
+                'BLOCK_M': 32,
+                'BLOCK_N': 32,
+                'waves_per_eu': 4,
+                'PRE_LOAD_V': False
             },
             num_stages=1,
-            num_warps=8,
-        ),
+            num_warps=8),
         # TODO: This config fails with head_size not pow2 with data mismatches.
         # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
         #                'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
@@ -314,8 +308,93 @@ def _attn_fwd_inner(
         #     num_stages=1,
         #     num_warps=4,
         # ),
-    ],
-    key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8'],
+    ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8']
+
+
+def get_rdna_autotune_configs():
+    return [
+        triton.Config(
+            {
+                'BLOCK_M': 32,
+                'BLOCK_N': 32,
+                'waves_per_eu': 4,
+                'PRE_LOAD_V': False
+            },
+            num_stages=1,
+            num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_M': 32,
+                'BLOCK_N': 32,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
+            },
+            num_stages=1,
+            num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_M': 32,
+                'BLOCK_N': 16,
+                'waves_per_eu': 4,
+                'PRE_LOAD_V': False
+            },
+            num_stages=1,
+            num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_M': 32,
+                'BLOCK_N': 16,
+                'waves_per_eu': 2,
+                'PRE_LOAD_V': False
+            },
+            num_stages=1,
+            num_warps=2),
+        # Fails in AccelerateAMDMatmul (Triton) assert when using FP8:
+        # triton.Config(
+        #     {
+        #         'BLOCK_M': 16,
+        #         'BLOCK_N': 16,
+        #         'waves_per_eu': 4,
+        #         'PRE_LOAD_V': False
+        #     },
+        #     num_stages=1,
+        #     num_warps=2),
+        # triton.Config(
+        #     {
+        #         'BLOCK_M': 16,
+        #         'BLOCK_N': 16,
+        #         'waves_per_eu': 2,
+        #         'PRE_LOAD_V': False
+        #     },
+        #     num_stages=1,
+        #     num_warps=2),
+        # # Fall-back config.
+        # triton.Config(
+        #     {
+        #         'BLOCK_M': 16,
+        #         'BLOCK_N': 16,
+        #         'waves_per_eu': 1,
+        #         'PRE_LOAD_V': False
+        #     },
+        #     num_stages=1,
+        #     num_warps=2),
+    ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8']
+
+
+def get_autotune_configs():
+    if is_navi():
+        return get_rdna_autotune_configs()
+    else:
+        return get_cdna_autotune_configs()
+
+
+autotune_configs, autotune_keys = get_autotune_configs()
+
+
+@triton.autotune(
+    configs=autotune_configs,
+    key=autotune_keys,
+    use_cuda_graph=True,
 )
 @triton.jit
 def attn_fwd(
@@ -833,6 +912,10 @@ def check_and_convert(t, scale):
        p_descale = 1.0 / p_scale
        o_descale = 1.0 / o_scale
 
+        if is_navi():
+            max_seqlens_q = 0
+            max_seqlens_k = 0
+
        attn_fwd[grid](
            q,
            k,
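
For context, the selection pattern this patch introduces reduces to the standalone sketch below. It is illustrative only and not part of the patch: plain dicts stand in for triton.Config objects, and probe_is_navi() is a hypothetical placeholder for vllm.utils.is_navi(). The point is that the CDNA/RDNA table is resolved once at import time, so @triton.autotune receives a fixed config list and key list for the detected GPU family.

# Illustrative sketch (assumed names, not the vLLM API): hardware-gated
# autotune tables chosen at import time, mirroring get_autotune_configs().

def cdna_table():
    # CDNA (MI-series) tables in the patch use larger tiles and more warps.
    configs = [
        {'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'num_warps': 8},
        {'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'num_warps': 4},
    ]
    keys = ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8']
    return configs, keys


def rdna_table():
    # RDNA (Navi) tables in the patch use smaller tiles and fewer warps.
    configs = [
        {'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'num_warps': 2},
        {'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 2, 'num_warps': 2},
    ]
    keys = ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8']
    return configs, keys


def probe_is_navi() -> bool:
    # Hypothetical probe; the patch calls vllm.utils.is_navi() here.
    return False


# Resolved once when the module is imported, like autotune_configs /
# autotune_keys in the patch, so a decorator sees a fixed table.
configs, keys = rdna_table() if probe_is_navi() else cdna_table()

if __name__ == '__main__':
    print(f'{len(configs)} candidate configs, keyed on {keys}')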