Skip to content

Commit

Permalink
Fix kernel cache miss and add RDNA configs
Browse files Browse the repository at this point in the history
- added Navi configurations (Related PR: ROCm/triton#640)
- resolved cache miss issue during flash attention calls by fixing max_seqlen_q/k to 0
  • Loading branch information
hyoon1 committed Nov 27, 2024
1 parent 529cefe commit ce52a5e
Showing 1 changed file with 139 additions and 58 deletions.
197 changes: 139 additions & 58 deletions vllm/attention/ops/triton_flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import triton
import triton.language as tl

from vllm.utils import is_navi

torch_dtype: tl.constexpr = torch.float16


Expand Down Expand Up @@ -207,103 +209,178 @@ def _attn_fwd_inner(
return acc, l_i, m_i


@triton.autotune(
configs=[
def get_cdna_autotune_configs():
return [
triton.Config(
{
"BLOCK_M": 256,
"BLOCK_N": 64,
"waves_per_eu": 2,
"PRE_LOAD_V": False,
'BLOCK_M': 256,
'BLOCK_N': 64,
'waves_per_eu': 2,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=8,
),
num_warps=8),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 128,
"waves_per_eu": 2,
"PRE_LOAD_V": False,
'BLOCK_M': 128,
'BLOCK_N': 128,
'waves_per_eu': 2,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=4,
),
num_warps=4),
triton.Config(
{
"BLOCK_M": 256,
"BLOCK_N": 128,
"waves_per_eu": 2,
"PRE_LOAD_V": False,
'BLOCK_M': 256,
'BLOCK_N': 128,
'waves_per_eu': 2,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=8,
),
num_warps=8),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"waves_per_eu": 1,
"PRE_LOAD_V": False,
'BLOCK_M': 128,
'BLOCK_N': 64,
'waves_per_eu': 1,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=4,
),
num_warps=4),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"waves_per_eu": 3,
"PRE_LOAD_V": True,
'BLOCK_M': 128,
'BLOCK_N': 64,
'waves_per_eu': 3,
'PRE_LOAD_V': True
},
num_stages=1,
num_warps=4,
),
num_warps=4),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"waves_per_eu": 3,
"PRE_LOAD_V": False,
'BLOCK_M': 128,
'BLOCK_N': 64,
'waves_per_eu': 3,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=4,
),
num_warps=4),
triton.Config(
{
"BLOCK_M": 64,
"BLOCK_N": 64,
"waves_per_eu": 4,
"PRE_LOAD_V": False,
'BLOCK_M': 64,
'BLOCK_N': 64,
'waves_per_eu': 4,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=8,
),
num_warps=8),
triton.Config(
{
"BLOCK_M": 32,
"BLOCK_N": 32,
"waves_per_eu": 4,
"PRE_LOAD_V": False,
'BLOCK_M': 32,
'BLOCK_N': 32,
'waves_per_eu': 4,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=8,
),
num_warps=8),
# TODO: This config fails with head_size not pow2 with data mismatches.
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
# 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
triton.Config(
{
"BLOCK_M": 16,
"BLOCK_N": 16,
"waves_per_eu": 1,
"PRE_LOAD_V": False,
'BLOCK_M': 16,
'BLOCK_N': 16,
'waves_per_eu': 1,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=4,
),
],
key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],
num_warps=4),
], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL']


def get_rdna_autotune_configs():
return [
triton.Config(
{
'BLOCK_M': 32,
'BLOCK_N': 32,
'waves_per_eu': 4,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=2),
triton.Config(
{
'BLOCK_M': 32,
'BLOCK_N': 32,
'waves_per_eu': 2,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=2),
triton.Config(
{
'BLOCK_M': 32,
'BLOCK_N': 16,
'waves_per_eu': 4,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=2),
triton.Config(
{
'BLOCK_M': 32,
'BLOCK_N': 16,
'waves_per_eu': 2,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=2),
triton.Config(
{
'BLOCK_M': 16,
'BLOCK_N': 16,
'waves_per_eu': 4,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=2),
triton.Config(
{
'BLOCK_M': 16,
'BLOCK_N': 16,
'waves_per_eu': 2,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=2),
# Fall-back config.
triton.Config(
{
'BLOCK_M': 16,
'BLOCK_N': 16,
'waves_per_eu': 1,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=2),
], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL']


def get_autotune_configs():
if is_navi():
return get_rdna_autotune_configs()
else:
return get_cdna_autotune_configs()


autotune_configs, autotune_keys = get_autotune_configs()


@triton.autotune(
configs=autotune_configs,
key=autotune_keys,
use_cuda_graph=True,
)
@triton.jit
def attn_fwd(
Expand Down Expand Up @@ -773,6 +850,10 @@ def forward(
else:
bias_strides = (0, 0, 0, 0)

if is_navi():
max_seqlens_q = 0
max_seqlens_k = 0

attn_fwd[grid](
q,
k,
Expand Down

0 comments on commit ce52a5e

Please sign in to comment.