add back flash_attn_func api (and support FA3)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
LucasWilkinson committed Jan 26, 2025
1 parent d4e0903 commit a14a552
Showing 6 changed files with 227 additions and 15 deletions.
1 change: 0 additions & 1 deletion CMakeLists.txt
@@ -231,7 +231,6 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
# FLASHATTENTION_DISABLE_LOCAL
FLASHATTENTION_DISABLE_PYBIND
FLASHATTENTION_DISABLE_FP8 # TODO Enable FP8
FLASHATTENTION_VARLEN_ONLY # Custom flag to save on binary size
)
elseif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 12.0)
message(STATUS "FA3 is disabled because CUDA version is not 12.0 or later.")
21 changes: 21 additions & 0 deletions csrc/flash_attn/flash_api_torch_lib.cpp
@@ -14,6 +14,21 @@ namespace FLASH_NAMESPACE {

////////////////////////////// From flash_api.cpp //////////////////////////////

std::vector<at::Tensor>
mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout,
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_);

std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
@@ -105,6 +120,12 @@ mha_varlen_fwd_sparse(at::Tensor &q, // total_q x num_heads x head_size, total_
* Torch Library Registration
*/
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("fwd(Tensor! q, Tensor k, Tensor v, Tensor!? out, Tensor? alibi_slopes, "
"float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, "
"float softcap, bool return_softmax, Generator? gen)"
"-> Tensor[]");
ops.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));

ops.def("varlen_fwd(Tensor! q, Tensor k, Tensor v, Tensor!? out, Tensor cu_seqlens_q, "
"Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? leftpad_k, Tensor? block_table, Tensor? alibi_slopes, "
"int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool zero_tensors, "
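For reference, the newly registered `fwd` schema is what the FA2 path of `flash_attn_interface.py` (added later in this commit) dispatches to. Below is a minimal sketch of that call; it assumes the `_vllm_fa2_C` extension has been built and that importing `vllm_flash_attn` loads and registers the op, and the shapes and dtype are illustrative only.

import torch
import vllm_flash_attn  # noqa: F401  (assumed to load the compiled extension and register the ops)

# Illustrative shapes; head_size must meet the kernel's alignment requirements.
batch, seqlen, nheads, head_size = 2, 128, 8, 64
q = torch.randn(batch, seqlen, nheads, head_size, device="cuda", dtype=torch.float16)
k = torch.randn(batch, seqlen, nheads, head_size, device="cuda", dtype=torch.float16)
v = torch.randn(batch, seqlen, nheads, head_size, device="cuda", dtype=torch.float16)

# Argument order mirrors the "fwd(...)" schema registered above.
out, softmax_lse = torch.ops._vllm_fa2_C.fwd(
    q, k, v,
    None,               # out (allocated by the kernel when None)
    None,               # alibi_slopes
    0.0,                # p_dropout
    head_size ** -0.5,  # softmax_scale
    True,               # is_causal
    -1, -1,             # window_size_left, window_size_right (-1 = unbounded)
    0.0,                # softcap (0.0 = disabled)
    False,              # return_softmax
    None,               # gen
)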
8 changes: 0 additions & 8 deletions hopper/static_switch.h
@@ -117,14 +117,6 @@
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
}()
#elif defined(FLASHATTENTION_VARLEN_ONLY)
#define VARLEN_SWITCH(COND, CONST_NAME, ...) \
[&] { \
TORCH_CHECK(COND, "This flash attention build only supports varlen " \
"(for build size reasons)."); \
constexpr static bool CONST_NAME = true; \
return __VA_ARGS__(); \
}()
#else
#define VARLEN_SWITCH BOOL_SWITCH
#endif
102 changes: 96 additions & 6 deletions tests/test_vllm_flash_attn.py
@@ -11,9 +11,11 @@
from einops import rearrange, repeat

from vllm_flash_attn.flash_attn_interface import (
flash_attn_func,
flash_attn_varlen_func,
flash_attn_with_kvcache,
is_fa_version_supported
is_fa_version_supported,
fa_version_unsupported_reason
)

NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
@@ -23,15 +25,49 @@
# one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048]
VERSIONS = \
([2] if is_fa_version_supported(2) else []) + \
([3] if is_fa_version_supported(3) else [])
VERSIONS = [2, 3]


def construct_local_mask(
seqlen_q,
seqlen_k,
window_size=(-1, -1), # -1 means infinite window size
query_padding_mask=None,
key_padding_mask=None,
device=None,
key_leftpad=None,
):
row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
if key_leftpad is not None:
key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
sk = (
seqlen_k
if key_padding_mask is None
else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
)
sq = (
seqlen_q
if query_padding_mask is None
else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
)
if window_size[0] < 0:
return col_idx > row_idx + sk - sq + window_size[1]
else:
sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
return torch.logical_or(
col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
col_idx < row_idx + sk - sq - window_size[0],
)


def ref_attn(
q,
k,
v,
scale,
query_padding_mask=None,
key_padding_mask=None,
attn_bias=None,
@@ -74,10 +110,11 @@ def ref_attn(
k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
d = q.shape[-1]
q *= scale
if not reorder_ops:
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
scores = torch.einsum("bthd,bshd->bhts", q, k)
else:
scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
scores = torch.einsum("bthd,bshd->bhts", q, k)

lse_ref = scores.logsumexp(dim=-1)

@@ -178,6 +215,59 @@ def ref_paged_attn(
return torch.cat(outputs, dim=0)


@pytest.mark.parametrize("seq_len", [1, 10, 256, 533])
@pytest.mark.parametrize("batch_size", [1, 7, 32])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("fa_version", VERSIONS)
@torch.inference_mode()
def test_flash_attn(
seq_len: int,
batch_size: int,
num_heads: Tuple[int, int],
head_size: int,
dtype: torch.dtype,
soft_cap: Optional[float],
fa_version: int,
) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
scale = head_size**-0.5

query = torch.randn(
batch_size, seq_len, num_query_heads, head_size, dtype=dtype)
key = torch.randn(
batch_size, seq_len, num_kv_heads, head_size, dtype=dtype)
value = torch.randn(
batch_size, seq_len, num_kv_heads, head_size, dtype=dtype)

output = flash_attn_func(
query,
key,
value,
softmax_scale=scale,
causal=True,
softcap=soft_cap if soft_cap is not None else 0,
fa_version=fa_version,
)

ref_output, _ = ref_attn(
q=query,
k=key,
v=value,
scale=scale,
causal=True,
softcap=soft_cap if soft_cap is not None else 0,
)
torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2,
    msg=f"max abs diff: {torch.max(torch.abs(output - ref_output))}")


@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
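As a sanity check on the masking convention exercised by `construct_local_mask` above (and documented on `flash_attn_func` below), here is a small standalone sketch, with illustrative variable names, that reproduces the bottom-right-aligned causal mask for seqlen_q = 2, seqlen_k = 5.

import torch

# Causal masking is aligned to the bottom-right corner of the attention matrix,
# i.e. query i may attend to key j only when j <= i + seqlen_k - seqlen_q.
seqlen_q, seqlen_k = 2, 5
row_idx = torch.arange(seqlen_q).unsqueeze(-1)  # shape (seqlen_q, 1)
col_idx = torch.arange(seqlen_k)                # shape (seqlen_k,)
keep = col_idx <= row_idx + seqlen_k - seqlen_q
print(keep.int())
# tensor([[1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])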
1 change: 1 addition & 0 deletions vllm_flash_attn/__init__.py
@@ -2,6 +2,7 @@

# Use relative import to support build-from-source installation in vLLM
from .flash_attn_interface import (
flash_attn_func,
flash_attn_varlen_func,
flash_attn_with_kvcache,
sparse_attn_func,
109 changes: 109 additions & 0 deletions vllm_flash_attn/flash_attn_interface.py
@@ -73,6 +73,115 @@ def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x


def flash_attn_func(
q,
k,
v,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
softcap=0.0, # 0.0 means deactivated
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
*,
return_softmax_lse=False,
out=None,
fa_version: int = DEFAULT_FA_VERSION,
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k: (batch_size, seqlen, nheads_k, headdim)
v: (batch_size, seqlen, nheads_k, headdim)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
"""
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)

if fa_version == 2:
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, softmax_lse = torch.ops._vllm_fa2_C.fwd(
q,
k,
v,
out,
alibi_slopes,
dropout_p,
softmax_scale,
causal,
window_size[0], window_size[1],
softcap,
return_softmax_lse and dropout_p > 0,
None,
)
elif fa_version == 3:
out, softmax_lse, _, _ = torch.ops._vllm_fa3_C.fwd(
q, k, v,
None, None, # k_new, v_new
out,
None, None, # cu_seqlens_q, cu_seqlens_k
None, # cu_seqlens_k_new
None, None, # seqused_q, seqused_k
None, None, # max_seqlen_q, max_seqlen_k
None,
alibi_slopes,
None, # kv_batch_idx
None, None, # rotary_cos, rotary_sin
None, None, None, # q_descale, k_descale, v_descale
softmax_scale,
causal,
window_size[0], window_size[1],
0, # sink_token_length
softcap,
True, # rotary_interleaved
0, # num_splits
None, # pack_gqa
0, # sm_margin
)

return (out, softmax_lse) if return_softmax_lse else out


def flash_attn_varlen_func(
q,
k,
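A minimal usage sketch of the restored `flash_attn_func` entry point, mirroring the new `test_flash_attn` test; the shapes, dtype, and FA version chosen here are illustrative and assume a CUDA device with the corresponding kernels built.

import torch
from vllm_flash_attn import flash_attn_func

torch.set_default_device("cuda")
batch_size, seq_len, num_query_heads, num_kv_heads, head_size = 2, 256, 8, 2, 128
q = torch.randn(batch_size, seq_len, num_query_heads, head_size, dtype=torch.bfloat16)
k = torch.randn(batch_size, seq_len, num_kv_heads, head_size, dtype=torch.bfloat16)
v = torch.randn(batch_size, seq_len, num_kv_heads, head_size, dtype=torch.bfloat16)

out = flash_attn_func(
    q, k, v,
    softmax_scale=head_size ** -0.5,
    causal=True,
    fa_version=2,  # or 3, when FA3 kernels are available on the current GPU
)
print(out.shape)  # torch.Size([2, 256, 8, 128])

# Pass return_softmax_lse=True to additionally get the per-row logsumexp.
out, lse = flash_attn_func(q, k, v, causal=True, return_softmax_lse=True, fa_version=2)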
