From ab33e0f99dca34b291d20849f56c5fa7d2b2da9e Mon Sep 17 00:00:00 2001 From: Hosang Yoon Date: Fri, 25 Oct 2024 13:29:31 -0400 Subject: [PATCH 1/2] Fix kernel cache miss and add RDNA configs - added Navi configurations (Related PR: https://github.com/ROCm/triton/pull/640) - resolved cache miss issue during flash attention calls by fixing max_seqlen_q/k to 0 --- vllm/attention/ops/triton_flash_attention.py | 186 +++++++++++++------ 1 file changed, 134 insertions(+), 52 deletions(-) diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index 3d53cd4b5700f..c151234812a68 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -24,6 +24,8 @@ import triton import triton.language as tl +from vllm.utils import is_navi + torch_dtype: tl.constexpr = torch.float16 @@ -217,88 +219,80 @@ def _attn_fwd_inner( return acc, l_i, m_i -@triton.autotune( - configs=[ +def get_cdna_autotune_configs(): + return [ triton.Config( { - "BLOCK_M": 256, - "BLOCK_N": 64, - "waves_per_eu": 2, - "PRE_LOAD_V": False, + 'BLOCK_M': 256, + 'BLOCK_N': 64, + 'waves_per_eu': 2, + 'PRE_LOAD_V': False }, num_stages=1, - num_warps=8, - ), + num_warps=8), triton.Config( { - "BLOCK_M": 128, - "BLOCK_N": 128, - "waves_per_eu": 2, - "PRE_LOAD_V": False, + 'BLOCK_M': 128, + 'BLOCK_N': 128, + 'waves_per_eu': 2, + 'PRE_LOAD_V': False }, num_stages=1, - num_warps=4, - ), + num_warps=4), triton.Config( { - "BLOCK_M": 256, - "BLOCK_N": 128, - "waves_per_eu": 2, - "PRE_LOAD_V": False, + 'BLOCK_M': 256, + 'BLOCK_N': 128, + 'waves_per_eu': 2, + 'PRE_LOAD_V': False }, num_stages=1, - num_warps=8, - ), + num_warps=8), triton.Config( { - "BLOCK_M": 128, - "BLOCK_N": 64, - "waves_per_eu": 1, - "PRE_LOAD_V": False, + 'BLOCK_M': 128, + 'BLOCK_N': 64, + 'waves_per_eu': 1, + 'PRE_LOAD_V': False }, num_stages=1, - num_warps=4, - ), + num_warps=4), triton.Config( { - "BLOCK_M": 128, - "BLOCK_N": 64, - "waves_per_eu": 3, - "PRE_LOAD_V": True, + 'BLOCK_M': 128, + 'BLOCK_N': 64, + 'waves_per_eu': 3, + 'PRE_LOAD_V': True }, num_stages=1, - num_warps=4, - ), + num_warps=4), triton.Config( { - "BLOCK_M": 128, - "BLOCK_N": 64, - "waves_per_eu": 3, - "PRE_LOAD_V": False, + 'BLOCK_M': 128, + 'BLOCK_N': 64, + 'waves_per_eu': 3, + 'PRE_LOAD_V': False }, num_stages=1, - num_warps=4, - ), + num_warps=4), triton.Config( { - "BLOCK_M": 64, - "BLOCK_N": 64, - "waves_per_eu": 4, - "PRE_LOAD_V": False, + 'BLOCK_M': 64, + 'BLOCK_N': 64, + 'waves_per_eu': 4, + 'PRE_LOAD_V': False }, num_stages=1, - num_warps=8, - ), + num_warps=8), triton.Config( { - "BLOCK_M": 32, - "BLOCK_N": 32, - "waves_per_eu": 4, - "PRE_LOAD_V": False, + 'BLOCK_M': 32, + 'BLOCK_N': 32, + 'waves_per_eu': 4, + 'PRE_LOAD_V': False }, num_stages=1, - num_warps=8, - ), + num_warps=8), # TODO: This config fails with head_size not pow2 with data mismatches. 
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, # 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), @@ -314,8 +308,92 @@ def _attn_fwd_inner( # num_stages=1, # num_warps=4, # ), - ], - key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8'], + ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8'], + + +def get_rdna_autotune_configs(): + return [ + triton.Config( + { + 'BLOCK_M': 32, + 'BLOCK_N': 32, + 'waves_per_eu': 4, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=2), + triton.Config( + { + 'BLOCK_M': 32, + 'BLOCK_N': 32, + 'waves_per_eu': 2, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=2), + triton.Config( + { + 'BLOCK_M': 32, + 'BLOCK_N': 16, + 'waves_per_eu': 4, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=2), + triton.Config( + { + 'BLOCK_M': 32, + 'BLOCK_N': 16, + 'waves_per_eu': 2, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=2), + triton.Config( + { + 'BLOCK_M': 16, + 'BLOCK_N': 16, + 'waves_per_eu': 4, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=2), + triton.Config( + { + 'BLOCK_M': 16, + 'BLOCK_N': 16, + 'waves_per_eu': 2, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=2), + # Fall-back config. + triton.Config( + { + 'BLOCK_M': 16, + 'BLOCK_N': 16, + 'waves_per_eu': 1, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=2), + ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'] + + +def get_autotune_configs(): + if is_navi(): + return get_rdna_autotune_configs() + else: + return get_cdna_autotune_configs() + + +autotune_configs, autotune_keys = get_autotune_configs() + + +@triton.autotune( + configs=autotune_configs, + key=autotune_keys, + use_cuda_graph=True, ) @triton.jit def attn_fwd( @@ -833,6 +911,10 @@ def check_and_convert(t, scale): p_descale = 1.0 / p_scale o_descale = 1.0 / o_scale + if is_navi(): + max_seqlens_q = 0 + max_seqlens_k = 0 + attn_fwd[grid]( q, k, From 5022d5c2a1d8b459c42984cf91ad54b91d3fc188 Mon Sep 17 00:00:00 2001 From: Hosang Yoon Date: Thu, 5 Dec 2024 20:33:08 -0500 Subject: [PATCH 2/2] Remove Navi autotune configs for triton FP8 support --- vllm/attention/ops/triton_flash_attention.py | 61 ++++++++++---------- 1 file changed, 31 insertions(+), 30 deletions(-) diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index c151234812a68..a49df831b46ea 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -308,7 +308,7 @@ def get_cdna_autotune_configs(): # num_stages=1, # num_warps=4, # ), - ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8'], + ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8'] def get_rdna_autotune_configs(): @@ -349,35 +349,36 @@ def get_rdna_autotune_configs(): }, num_stages=1, num_warps=2), - triton.Config( - { - 'BLOCK_M': 16, - 'BLOCK_N': 16, - 'waves_per_eu': 4, - 'PRE_LOAD_V': False - }, - num_stages=1, - num_warps=2), - triton.Config( - { - 'BLOCK_M': 16, - 'BLOCK_N': 16, - 'waves_per_eu': 2, - 'PRE_LOAD_V': False - }, - num_stages=1, - num_warps=2), - # Fall-back config. 
- triton.Config( - { - 'BLOCK_M': 16, - 'BLOCK_N': 16, - 'waves_per_eu': 1, - 'PRE_LOAD_V': False - }, - num_stages=1, - num_warps=2), - ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'] + # Fails in AccelerateAMDMatmul (Triton) assert when using FP8: + # triton.Config( + # { + # 'BLOCK_M': 16, + # 'BLOCK_N': 16, + # 'waves_per_eu': 4, + # 'PRE_LOAD_V': False + # }, + # num_stages=1, + # num_warps=2), + # triton.Config( + # { + # 'BLOCK_M': 16, + # 'BLOCK_N': 16, + # 'waves_per_eu': 2, + # 'PRE_LOAD_V': False + # }, + # num_stages=1, + # num_warps=2), + # # Fall-back config. + # triton.Config( + # { + # 'BLOCK_M': 16, + # 'BLOCK_N': 16, + # 'waves_per_eu': 1, + # 'PRE_LOAD_V': False + # }, + # num_stages=1, + # num_warps=2), + ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8'] def get_autotune_configs():
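The net effect of the two patches can be condensed into the short sketch below. It is an illustration, not the vLLM code: each config list is cut down to one representative entry per GPU family, and pin_seqlens_for_navi() is a hypothetical helper standing in for the in-place assignment patch 1 adds just before the attn_fwd launch. Only is_navi() (which the patch imports from vllm.utils), triton.Config, and the key list are taken from the source.

import triton

from vllm.utils import is_navi  # import added by patch 1


def get_autotune_configs():
    # Trimmed-down version of the per-architecture selection added in patch 1.
    if is_navi():
        # RDNA (Navi) parts: smaller tiles and fewer warps; one representative
        # entry shown here out of the several in the patch.
        configs = [
            triton.Config(
                {'BLOCK_M': 32, 'BLOCK_N': 32,
                 'waves_per_eu': 4, 'PRE_LOAD_V': False},
                num_stages=1,
                num_warps=2),
        ]
    else:
        # CDNA parts keep the original, larger configurations.
        configs = [
            triton.Config(
                {'BLOCK_M': 128, 'BLOCK_N': 128,
                 'waves_per_eu': 2, 'PRE_LOAD_V': False},
                num_stages=1,
                num_warps=4),
        ]
    # After patch 2 both paths autotune over the same key set, including
    # USE_FP8.
    return configs, ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8']


# Evaluated once at import time, exactly as in the patch, so the
# @triton.autotune decorator on attn_fwd can consume the result directly.
autotune_configs, autotune_keys = get_autotune_configs()


def pin_seqlens_for_navi(max_seqlens_q, max_seqlens_k):
    # Hypothetical helper mirroring the guard patch 1 inserts right before
    # attn_fwd[grid](...): on Navi the lengths are pinned to 0 so these scalar
    # kernel arguments are identical on every call, which, per the commit
    # message, avoids the kernel cache misses seen when they vary per request.
    if is_navi():
        return 0, 0
    return max_seqlens_q, max_seqlens_k

Routing both architectures through a single get_autotune_configs() entry point lets the @triton.autotune decorator stay unchanged while both the config list and the autotune key list vary with the detected GPU family.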