Fix kernel cache miss and add RDNA configs
- added Navi (RDNA) autotune configurations (related PR: ROCm/triton#640)
- resolved the kernel cache miss during flash attention calls by fixing MAX_SEQLENS_Q/K to 0
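Editor's note on the cache-miss fix: Triton compiles and caches a separate kernel specialization for every distinct value of a tl.constexpr argument, so passing the per-batch max_seqlens_q/k into MAX_SEQLENS_Q/K (assuming those are constexpr parameters of attn_fwd, as their naming suggests) forced a recompile whenever the batch's sequence lengths changed. In varlen mode the kernel takes the real lengths from the cu_seqlens_q/k tensors, so the launch site can pin the constants to 0 and keep the cache key stable. A minimal sketch of the caching behaviour with a hypothetical toy kernel, not code from this file (assumes a GPU build of PyTorch and Triton):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def fill_iota(out_ptr, N: tl.constexpr):
        # Each distinct value of the constexpr N compiles (and caches) its own
        # kernel specialization, analogous to MAX_SEQLENS_Q/K in attn_fwd.
        offs = tl.arange(0, N)
        tl.store(out_ptr + offs, offs)

    buf = torch.empty(16, device="cuda", dtype=torch.int32)
    fill_iota[(1,)](buf, N=16)  # first call: compile and cache
    fill_iota[(1,)](buf, N=16)  # same constexpr value: cache hit
    fill_iota[(1,)](buf, N=8)   # new constexpr value: recompile (cache miss)

In the diff itself the fix is just the launch-site change in the last hunk: MAX_SEQLENS_Q=max_seqlens_q / MAX_SEQLENS_K=max_seqlens_k become MAX_SEQLENS_Q=0 / MAX_SEQLENS_K=0.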
hyoon1 committed Oct 25, 2024
1 parent 030374b commit c50a528
Showing 1 changed file with 85 additions and 98 deletions.
183 changes: 85 additions & 98 deletions vllm/attention/ops/triton_flash_attention.py
@@ -20,6 +20,7 @@
"""

import subprocess
import torch
import triton
import triton.language as tl
@@ -206,105 +207,91 @@ def _attn_fwd_inner(
(0, BLOCK_N))
return acc, l_i, m_i

def get_gfx_version():
    try:
        # Run the rocminfo command
        result = subprocess.run(['rocminfo'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        output = result.stdout

        # Parse the output to find the gfx version
        for line in output.splitlines():
            line = line.strip()
            if line.startswith("Name: gfx"):
                gfx_version = line.split("Name:")[1].strip()
                return gfx_version
    except Exception as e:
        print(f"Error: {e}")
    return None

def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')


def is_rdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ("gfx1030", "gfx1100", "gfx1101",
                                                                                   "gfx1102", "gfx1200", "gfx1201")


def get_cdna_autotune_configs():
    return [
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=4),
        # Fall-back config.
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=4),
    ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL']


def get_rdna_autotune_configs():
    return [
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=2),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=2),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=2),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=2),
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=2),
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=2),
        # Fall-back config.
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=2),
    ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL']


def get_autotune_configs():
    if is_rdna():
        return get_rdna_autotune_configs()
    elif is_cdna():
        return get_cdna_autotune_configs()
    else:
        raise ValueError("Unknown Device Type")


autotune_configs, autotune_keys = get_autotune_configs()

@triton.autotune(
    configs=[
        triton.Config(
            {
                "BLOCK_M": 256,
                "BLOCK_N": 64,
                "waves_per_eu": 2,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=8,
        ),
        triton.Config(
            {
                "BLOCK_M": 128,
                "BLOCK_N": 128,
                "waves_per_eu": 2,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=4,
        ),
        triton.Config(
            {
                "BLOCK_M": 256,
                "BLOCK_N": 128,
                "waves_per_eu": 2,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=8,
        ),
        triton.Config(
            {
                "BLOCK_M": 128,
                "BLOCK_N": 64,
                "waves_per_eu": 1,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=4,
        ),
        triton.Config(
            {
                "BLOCK_M": 128,
                "BLOCK_N": 64,
                "waves_per_eu": 3,
                "PRE_LOAD_V": True,
            },
            num_stages=1,
            num_warps=4,
        ),
        triton.Config(
            {
                "BLOCK_M": 128,
                "BLOCK_N": 64,
                "waves_per_eu": 3,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=4,
        ),
        triton.Config(
            {
                "BLOCK_M": 64,
                "BLOCK_N": 64,
                "waves_per_eu": 4,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=8,
        ),
        triton.Config(
            {
                "BLOCK_M": 32,
                "BLOCK_N": 32,
                "waves_per_eu": 4,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=8,
        ),
        # TODO: This config fails with head_size not pow2 with data mismatches.
        # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
        # 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
        triton.Config(
            {
                "BLOCK_M": 16,
                "BLOCK_N": 16,
                "waves_per_eu": 1,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=4,
        ),
    ],
    key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],
    configs=autotune_configs,
    key=autotune_keys,
    use_cuda_graph=True,
)

@triton.jit
def attn_fwd(
Q,
@@ -795,8 +782,8 @@ def forward(
            HQ=nheads_q,
            HK=nheads_k,
            ACTUAL_BLOCK_DMODEL=head_size,
            MAX_SEQLENS_Q=max_seqlens_q,
            MAX_SEQLENS_K=max_seqlens_k,
            MAX_SEQLENS_Q=0,
            MAX_SEQLENS_K=0,
            IS_CAUSAL=causal,
            VARLEN=True,
            BLOCK_DMODEL=padded_d_model,
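Editor's note: after this change the autotune-config family is chosen once, at module import time, from the detected GPU architecture (get_rdna_autotune_configs() for Navi/RDNA parts, get_cdna_autotune_configs() for CDNA parts, and a ValueError("Unknown Device Type") otherwise), so importing the module on an unsupported target will raise. A rough smoke test of which path a given ROCm machine takes, using only the helpers added in this diff (hypothetical usage; assumes vLLM and a HIP-backend Triton are installed):

    from vllm.attention.ops.triton_flash_attention import (
        get_autotune_configs, is_cdna, is_rdna)

    # Architecture gates introduced by this commit.
    print("CDNA:", is_cdna(), "RDNA (Navi):", is_rdna())

    # Same call the module makes at import time; ValueError on other targets.
    configs, keys = get_autotune_configs()
    print(len(configs), "candidate configs; autotune key:", keys)

The RDNA list uses smaller tiles (BLOCK_M/BLOCK_N of 16 or 32) and num_warps=2, compared with the larger CDNA tiles; the commit message points to ROCm/triton#640 as the related PR for these Navi configurations.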
