diff --git a/CMakeLists.txt b/CMakeLists.txt
index 759c87f2e..4c3e61d85 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -231,7 +231,6 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
         # FLASHATTENTION_DISABLE_LOCAL
         FLASHATTENTION_DISABLE_PYBIND
         FLASHATTENTION_DISABLE_FP8 # TODO Enable FP8
-        FLASHATTENTION_VARLEN_ONLY # Custom flag to save on binary size
     )
 elseif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 12.0)
     message(STATUS "FA3 is disabled because CUDA version is not 12.0 or later.")
diff --git a/csrc/flash_attn/flash_api_torch_lib.cpp b/csrc/flash_attn/flash_api_torch_lib.cpp
index 5c3f7af5d..4156d60b5 100644
--- a/csrc/flash_attn/flash_api_torch_lib.cpp
+++ b/csrc/flash_attn/flash_api_torch_lib.cpp
@@ -14,6 +14,21 @@ namespace FLASH_NAMESPACE {
 
 ////////////////////////////// From flash_api.cpp //////////////////////////////
 
+std::vector<at::Tensor>
+mha_fwd(at::Tensor &q,                    // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
+        const at::Tensor &k,              // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
+        const at::Tensor &v,              // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
+        std::optional<at::Tensor> &out_,  // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
+        std::optional<at::Tensor> &alibi_slopes_,  // num_heads or batch_size x num_heads
+        const float p_dropout,
+        const float softmax_scale,
+        bool is_causal,
+        int window_size_left,
+        int window_size_right,
+        const float softcap,
+        const bool return_softmax,
+        std::optional<at::Generator> gen_);
+
 std::vector<at::Tensor>
 mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
                const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
@@ -105,6 +120,12 @@ mha_varlen_fwd_sparse(at::Tensor &q, // total_q x num_heads x head_size, total_
  * Torch Library Registration
  */
 TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+  ops.def("fwd(Tensor! q, Tensor k, Tensor v, Tensor!? out, Tensor? alibi_slopes, "
+          "float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, "
+          "float softcap, bool return_softmax, Generator? gen)"
+          "-> Tensor[]");
+  ops.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
+
   ops.def("varlen_fwd(Tensor! q, Tensor k, Tensor v, Tensor!? out, Tensor cu_seqlens_q, "
           "Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? leftpad_k, Tensor? block_table, Tensor? alibi_slopes, "
           "int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool zero_tensors, "
diff --git a/hopper/static_switch.h b/hopper/static_switch.h
index 4701fa202..a50ee24bc 100644
--- a/hopper/static_switch.h
+++ b/hopper/static_switch.h
@@ -117,14 +117,6 @@
       constexpr static bool CONST_NAME = false;                            \
       return __VA_ARGS__();                                                \
     }()
-#elif defined(FLASHATTENTION_VARLEN_ONLY)
-  #define VARLEN_SWITCH(COND, CONST_NAME, ...)                             \
-    [&] {                                                                  \
-      TORCH_CHECK(COND, "This flash attention build only supports varlen " \
-                        "(for build size reasons).");                      \
-      constexpr static bool CONST_NAME = true;                             \
-      return __VA_ARGS__();                                                \
-    }()
 #else
   #define VARLEN_SWITCH BOOL_SWITCH
 #endif
diff --git a/tests/test_vllm_flash_attn.py b/tests/test_vllm_flash_attn.py
index a49ce4782..ae5f16636 100644
--- a/tests/test_vllm_flash_attn.py
+++ b/tests/test_vllm_flash_attn.py
@@ -11,9 +11,11 @@
 from einops import rearrange, repeat
 
 from vllm_flash_attn.flash_attn_interface import (
+    flash_attn_func,
     flash_attn_varlen_func,
     flash_attn_with_kvcache,
-    is_fa_version_supported
+    is_fa_version_supported,
+    fa_version_unsupported_reason
 )
 
 NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
@@ -23,15 +25,49 @@
 # one value large enough to test overflow in index calculation.
 # one value small enough to test the schema op check
 NUM_BLOCKS = [32768, 2048]
-VERSIONS = \
-    ([2] if is_fa_version_supported(2) else []) + \
-    ([3] if is_fa_version_supported(3) else [])
+VERSIONS = [2, 3]
+
+
+def construct_local_mask(
+    seqlen_q,
+    seqlen_k,
+    window_size=(-1, -1),  # -1 means infinite window size
+    query_padding_mask=None,
+    key_padding_mask=None,
+    device=None,
+    key_leftpad=None,
+):
+    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
+    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
+    if key_leftpad is not None:
+        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
+        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
+        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
+    sk = (
+        seqlen_k
+        if key_padding_mask is None
+        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
+    )
+    sq = (
+        seqlen_q
+        if query_padding_mask is None
+        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
+    )
+    if window_size[0] < 0:
+        return col_idx > row_idx + sk - sq + window_size[1]
+    else:
+        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
+        return torch.logical_or(
+            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
+            col_idx < row_idx + sk - sq - window_size[0],
+        )
 
 
 def ref_attn(
     q,
     k,
     v,
+    scale,
     query_padding_mask=None,
     key_padding_mask=None,
     attn_bias=None,
@@ -74,10 +110,11 @@ def ref_attn(
     k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
     v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
     d = q.shape[-1]
+    q *= scale
     if not reorder_ops:
-        scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
+        scores = torch.einsum("bthd,bshd->bhts", q, k)
     else:
-        scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
+        scores = torch.einsum("bthd,bshd->bhts", q, k)
 
     lse_ref = scores.logsumexp(dim=-1)
 
@@ -178,6 +215,59 @@ def ref_paged_attn(
     return torch.cat(outputs, dim=0)
 
 
+@pytest.mark.parametrize("seq_len", [1, 10, 256, 533])
+@pytest.mark.parametrize("batch_size", [1, 7, 32])
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
+@pytest.mark.parametrize("fa_version", VERSIONS)
+@torch.inference_mode()
+def test_flash_attn(
+    seq_len: int,
+    batch_size: int,
+    num_heads: Tuple[int, int],
+    head_size: int,
+    dtype: torch.dtype,
+    soft_cap: Optional[float],
+    fa_version: int,
+) -> None:
+    torch.set_default_device("cuda")
+    torch.cuda.manual_seed_all(0)
+    num_query_heads = num_heads[0]
+    num_kv_heads = num_heads[1]
+    assert num_query_heads % num_kv_heads == 0
+    scale = head_size**-0.5
+
+    query = torch.randn(
+        batch_size, seq_len, num_query_heads, head_size, dtype=dtype)
+    key = torch.randn(
+        batch_size, seq_len, num_kv_heads, head_size, dtype=dtype)
+    value = torch.randn(
+        batch_size, seq_len, num_kv_heads, head_size, dtype=dtype)
+
+    output = flash_attn_func(
+        query,
+        key,
+        value,
+        softmax_scale=scale,
+        causal=True,
+        softcap=soft_cap if soft_cap is not None else 0,
+        fa_version=fa_version,
+    )
+
+    ref_output, _ = ref_attn(
+        q=query,
+        k=key,
+        v=value,
+        scale=scale,
+        causal=True,
+        softcap=soft_cap if soft_cap is not None else 0,
+    )
+    torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
+        f"{torch.max(torch.abs(output - ref_output))}"
+
+
 @pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
diff --git a/vllm_flash_attn/__init__.py b/vllm_flash_attn/__init__.py
index f2013a199..2f73f9dde 100644
--- a/vllm_flash_attn/__init__.py
+++ b/vllm_flash_attn/__init__.py
@@ -2,6 +2,7 @@
 
 # Use relative import to support build-from-source installation in vLLM
 from .flash_attn_interface import (
+    flash_attn_func,
     flash_attn_varlen_func,
     flash_attn_with_kvcache,
     sparse_attn_func,
diff --git a/vllm_flash_attn/flash_attn_interface.py b/vllm_flash_attn/flash_attn_interface.py
index 81e2c22e5..8c712a68b 100644
--- a/vllm_flash_attn/flash_attn_interface.py
+++ b/vllm_flash_attn/flash_attn_interface.py
@@ -73,6 +73,115 @@ def maybe_contiguous(x):
     return x.contiguous() if x is not None and x.stride(-1) != 1 else x
 
 
+def flash_attn_func(
+    q,
+    k,
+    v,
+    dropout_p=0.0,
+    softmax_scale=None,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0,  # 0.0 means deactivated
+    alibi_slopes=None,
+    deterministic=False,
+    return_attn_probs=False,
+    *,
+    return_softmax_lse=False,
+    out=None,
+    fa_version: int = DEFAULT_FA_VERSION,
+):
+    """dropout_p should be set to 0.0 during evaluation
+    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
+    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
+    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
+    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
+
+    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
+    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
+        1 1 1 1 0
+        1 1 1 1 1
+    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
+        0 0
+        0 0
+        0 0
+        1 0
+        1 1
+    If the row of the mask is all zero, the output will be zero.
+
+    If window_size != (-1, -1), implements sliding window local attention. Query at position i
+    will only attend to keys between
+    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
+
+    Arguments:
+        q: (batch_size, seqlen, nheads, headdim)
+        k: (batch_size, seqlen, nheads_k, headdim)
+        v: (batch_size, seqlen, nheads_k, headdim)
+        dropout_p: float. Dropout probability.
+        softmax_scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
+        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+            is added to the attention score of query i and key j.
+        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+            which is slightly slower and uses more memory. The forward pass is always deterministic.
+        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
+            testing only. The returned probabilities are not guaranteed to be correct
+            (they might not have the right scaling).
+    Return:
+        out: (batch_size, seqlen, nheads, headdim).
+        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
+            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
+            normalization factor).
+    """
+    if softmax_scale is None:
+        softmax_scale = q.shape[-1] ** (-0.5)
+
+    if fa_version == 2:
+        q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
+        out, softmax_lse = torch.ops._vllm_fa2_C.fwd(
+            q,
+            k,
+            v,
+            out,
+            alibi_slopes,
+            dropout_p,
+            softmax_scale,
+            causal,
+            window_size[0], window_size[1],
+            softcap,
+            return_softmax_lse and dropout_p > 0,
+            None,
+        )
+    elif fa_version == 3:
+        out, softmax_lse, _, _ = torch.ops._vllm_fa3_C.fwd(
+            q, k, v,
+            None, None,  # k_new, v_new
+            out,
+            None, None,  # cu_seqlens_q, cu_seqlens_k
+            None,  # cu_seqlens_k_new
+            None, None,  # seqused_q, seqused_k
+            None, None,  # max_seqlen_q, max_seqlen_k
+            None,
+            alibi_slopes,
+            None,  # kv_batch_idx
+            None, None,  # rotary_cos, rotary_sin
+            None, None, None,  # q_descale, k_descale, v_descale
+            softmax_scale,
+            causal,
+            window_size[0], window_size[1],
+            0,  # sink_token_length
+            softcap,
+            True,  # rotary_interleaved
+            0,  # num_splits
+            None,  # pack_gqa
+            0,  # sm_margin
+        )
+
+    return (out, softmax_lse) if return_softmax_lse else out
+
+
 def flash_attn_varlen_func(
     q,
     k,
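
A minimal usage sketch of the new dense (non-varlen) flash_attn_func path added by this patch. It is not part of the diff; it assumes a CUDA device and a build with the FA2 and/or FA3 extension compiled in, and uses only names exported above (flash_attn_func, is_fa_version_supported, fa_version):

import torch
from vllm_flash_attn import flash_attn_func
from vllm_flash_attn.flash_attn_interface import is_fa_version_supported

# Prefer FA3 when the build/GPU supports it, otherwise fall back to FA2.
fa_version = 3 if is_fa_version_supported(3) else 2

# Dense layout: (batch_size, seqlen, nheads, headdim); GQA with 8 query heads sharing 2 KV heads.
q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 128, 2, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 128, 2, 64, device="cuda", dtype=torch.bfloat16)

out, lse = flash_attn_func(
    q, k, v,
    causal=True,              # bottom-right-aligned causal mask, as described in the docstring
    return_softmax_lse=True,  # also return the (batch_size, nheads, seqlen) logsumexp
    fa_version=fa_version,    # dispatches to torch.ops._vllm_fa2_C.fwd or _vllm_fa3_C.fwd
)
assert out.shape == q.shape and lse.shape == (2, 8, 128)

With softmax_scale left as None, the scale defaults to head_dim ** -0.5, matching the reference check in tests/test_vllm_flash_attn.py::test_flash_attn.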