
Commit

Merge remote-tracking branch 'origin/jpvillam/v0.3.3_triton' into integration
gshtras committed Mar 20, 2024
2 parents c45547b + d4cb905 commit 9cb2bbd
Showing 4 changed files with 71 additions and 55 deletions.
2 changes: 1 addition & 1 deletion Dockerfile.rocm
@@ -90,7 +90,7 @@ RUN if [ "$BUILD_TRITON" = "1" ]; then \
&& pip uninstall -y triton \
&& git clone https://github.com/ROCm/triton.git \
&& cd triton/python \
&& pip3 install -e . \
&& pip3 install . \
&& cd ../..; \
fi
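
The change from pip3 install -e . to pip3 install . installs the ROCm Triton build into site-packages instead of leaving an editable install that resolves back to the cloned triton/ source tree. A minimal check of which layout ended up in the image, assuming it is run inside the built ROCm container:

# Confirm Triton resolves from site-packages rather than an editable checkout.
import triton

print(triton.__version__)   # version of the ROCm Triton build
print(triton.__file__)      # expected to point into .../site-packages/triton/, not the cloned repo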

45 changes: 30 additions & 15 deletions vllm/model_executor/layers/attention/attention.py
@@ -8,6 +8,7 @@
from vllm.logger import init_logger
from vllm.model_executor.input_metadata import InputMetadata
from vllm.utils import is_hip
import os

logger = init_logger(__name__)

@@ -34,12 +35,12 @@ def __init__(
sliding_window: Optional[int] = None,
) -> None:
super().__init__()
if _use_flash_attn():

if use_triton := _use_flash_attn():
from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend # noqa: E501
self.backend = FlashAttentionBackend(num_heads, head_size, scale,
num_kv_heads, alibi_slopes,
sliding_window)
sliding_window,
use_triton == 2)
else:
from vllm.model_executor.layers.attention.backends.xformers import XFormersBackend # noqa: E501
self.backend = XFormersBackend(num_heads, head_size, scale,
@@ -60,23 +61,37 @@ def forward(


@lru_cache(maxsize=1)
def _use_flash_attn() -> bool:
try:
import flash_attn # noqa: F401
except ImportError:
logger.info("flash_attn is not found. Using xformers backend.")
return False

if torch.cuda.get_device_capability()[0] < 8:
def _use_flash_attn() -> int:
"""Returns if and which flash attention to use.
Returns:
int: 0 for none, 1 for default implementation, 2 for triton implementation.
"""
if not (os.environ.get('VLLM_USE_FLASH_ATTN_TRITON') and is_hip()):
# AMD GPUs can use flash_attn package or triton impl.
try:
import flash_attn # noqa: F401
except ImportError:
logger.info("flash_attn is not found. Using xformers backend.")
return 0

if (not is_hip()) and torch.cuda.get_device_capability()[0] < 8:
# Volta and Turing NVIDIA GPUs.
logger.info("flash_attn is not supported on Turing or older GPUs. "
"Using xformers backend.")
return False
return 0

if is_hip() and torch.cuda.get_device_capability()[0] != 9:
# not Instinct series GPUs.
logger.info("flash_atten is not supported on NAVI GPUs. "
"Using xformers backend.")
return 0

if torch.get_default_dtype() not in (torch.float16, torch.bfloat16):
logger.info(
"flash_attn only supports torch.float16 or torch.bfloat16. "
"Using xformers backend.")
return False
return 0

logger.info("Using flash_attn backend.")
return True
logger.info(f"Using {'Triton' if os.environ.get('VLLM_USE_FLASH_ATTN_TRITON') else ''} flash_attn backend.")
return 2 if os.environ.get('VLLM_USE_FLASH_ATTN_TRITON') else 1
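
_use_flash_attn() now returns a tri-state value instead of a bool: 0 falls back to xformers, 1 selects the flash_attn package, and 2 selects the ROCm Triton kernel when VLLM_USE_FLASH_ATTN_TRITON is set on a HIP device. The walrus assignment keeps that value so the backend can be told whether to take the Triton path. A stripped-down sketch of the intended control flow (hardware and dtype checks omitted), not the vLLM code itself:

import os

def _use_flash_attn_sketch() -> int:
    # Stand-in for the detection logic above: 0 = xformers,
    # 1 = flash_attn package, 2 = ROCm Triton flash-attention kernel.
    if os.environ.get("VLLM_USE_FLASH_ATTN_TRITON"):
        return 2
    return 1

def pick_backend() -> str:
    if use_triton := _use_flash_attn_sketch():
        # use_triton == 2 toggles the Triton kernel inside FlashAttentionBackend.
        return f"FlashAttentionBackend(use_triton={use_triton == 2})"
    return "XFormersBackend()"

print(pick_backend())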
46 changes: 25 additions & 21 deletions vllm/model_executor/layers/attention/backends/flash_attn.py
@@ -1,22 +1,14 @@
"""Attention layer with Flash and PagedAttention."""
from typing import List, Optional

# NOTE(woosuk): This imports flash_attn under vllm/thirdparty_files/.
from vllm.utils import is_hip
try:
from flash_attn import flash_attn_func
except ImportError:
if is_hip():
pass
else:
raise

from flash_attn import flash_attn_func
import torch

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention.ops.paged_attn import (
PagedAttentionImpl)
from vllm.model_executor.layers.attention.ops.flash_attention_triton import attention
from vllm.model_executor.layers.attention.ops.flash_attention_triton import triton_attention


class FlashAttentionBackend:
@@ -29,6 +21,7 @@ def __init__(
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
use_triton: Optional[bool] = False,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
@@ -46,6 +39,7 @@ def __init__(
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
self.kv_cache_scaling_factor = 1.0
self.use_triton = use_triton

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -104,8 +98,8 @@ def forward(
query = query.unflatten(0, (batch_size, seq_len))
key = key.unflatten(0, (batch_size, seq_len))
value = value.unflatten(0, (batch_size, seq_len))
if is_hip():
output, _ = attention(
if self.use_triton:
output, _ = triton_attention(
query,
key,
value,
@@ -115,15 +109,25 @@
self.scale,
)
else:
output = flash_attn_func(
query,
key,
value,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)
if is_hip():
#XXX: window_size and alibi_slopes not supported
output = flash_attn_func(
query,
key,
value,
softmax_scale=self.scale,
causal=True,
)
else:
output = flash_attn_func(
query,
key,
value,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)
else:
# prefix-enabled attention
output = PagedAttentionImpl.forward_prefix(
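
With the backend now taking an explicit use_triton flag, the Triton path is opted into through the VLLM_USE_FLASH_ATTN_TRITON environment variable rather than inferred from is_hip() alone, and the ROCm flash_attn_func call drops the unsupported window_size and alibi_slopes arguments. A hedged end-to-end usage sketch; the model name and sampling settings are illustrative, and the variable has to be set before vLLM imports the attention layer:

import os
os.environ["VLLM_USE_FLASH_ATTN_TRITON"] = "1"  # enable the Triton flash-attention kernel on ROCm

from vllm import LLM, SamplingParams  # assumes a ROCm build of vLLM is installed

llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)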
33 changes: 15 additions & 18 deletions vllm/model_executor/layers/attention/ops/flash_attention_triton.py
@@ -251,12 +251,12 @@ def attn_fwd(
)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
# We still need to write 0s to the result
tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0,1))
l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
#tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0,1))
#l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
# We store inf to LSE, not -inf because in the bwd pass, we subtract this
# from qk which makes it -inf, such that exp(qk - inf) = 0 for these masked blocks.
l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
tl.store(l_ptrs, l)
#l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
#tl.store(l_ptrs, l)
# TODO: Should dropout and return encoded softmax be handled here too?
return

@@ -417,17 +417,17 @@ def attn_fwd(
z = 0.0
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
# write back LSE
l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
#l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows.
# This is only true for the last M block. For others, overflow_size will be -ve
overflow_size = end_m_idx - seqlen_q
if overflow_size > 0:
boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
# This is a > check because mask being 0 blocks the store.
l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
else:
tl.store(l_ptrs, m_i + tl.math.log2(l_i))
#overflow_size = end_m_idx - seqlen_q
#if overflow_size > 0:
# boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
# # This is a > check because mask being 0 blocks the store.
# l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
# tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
#else:
# tl.store(l_ptrs, m_i + tl.math.log2(l_i))

# write back O
o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
@@ -494,8 +494,6 @@ def forward(ctx, q, k, v, o, metadata, causal=False, sm_scale=1.0, bias=None):

encoded_softmax = None

M = torch.empty((batch, nheads_q, metadata.max_seq_len), device=q.device, dtype=torch.float32)

# Seed the RNG so we get reproducible results for testing.
philox_seed = 0x1BF52
philox_offset = 0x1D4B42
@@ -507,7 +505,7 @@ def forward(ctx, q, k, v, o, metadata, causal=False, sm_scale=1.0, bias=None):
bias_strides = (0,0,0,0)

attn_fwd[grid](
q, k, v, bias, sm_scale, M, o,
q, k, v, bias, sm_scale, None, o,
*q_strides, *k_strides, *v_strides, *o_strides, *bias_strides,
None, None,
dropout_p=0.0,
@@ -526,7 +524,6 @@ def forward(ctx, q, k, v, o, metadata, causal=False, sm_scale=1.0, bias=None):
RETURN_ENCODED_SOFTMAX=False
)

ctx.save_for_backward(q, k, v, o, M)
ctx.grid = grid
ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = head_size
@@ -538,4 +535,4 @@ def forward(ctx, q, k, v, o, metadata, causal=False, sm_scale=1.0, bias=None):
ctx.return_encoded_softmax = False
return o, encoded_softmax

attention = _attention.apply
triton_attention = _attention.apply
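
The kernel-side stores that are commented out, together with the removed M tensor and ctx.save_for_backward call, all concern the per-row log-sum-exp statistic (m_i + log2(l_i)) that a backward pass would read back to rebuild the softmax; since vLLM only runs the forward pass at inference time, none of it has to be materialized. A small numerical sketch of that statistic, in natural-log form and with made-up shapes:

import torch

scores = torch.randn(4, 16)                    # one block of attention scores (rows = queries)
m_i = scores.max(dim=-1, keepdim=True).values  # running row maximum kept for numerical stability
l_i = torch.exp(scores - m_i).sum(dim=-1)      # row-wise sum of exponentials
lse = m_i.squeeze(-1) + torch.log(l_i)         # log-sum-exp; the kernel's base-2 analogue is m_i + log2(l_i)

# The softmax a backward pass would reconstruct from it:
probs = torch.exp(scores - lse[:, None])
assert torch.allclose(probs.sum(dim=-1), torch.ones(4))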
