Flash-Sparse-Attention is a high-performance trainable sparse attention implementation that integrates Flash Attention's memory efficiency with Dynamic Mask Attention's sparse computation capabilities for processing extremely long sequences in transformer models.
## Why Flash-Sparse-Attention
In large-scale Transformer training and inference, the dominant bottlenecks diverge:
- **Training-side compute bottleneck**: The computational complexity of full attention grows quadratically with sequence length, and backpropagation repeats computation of the same order, so massive compute is spent on key-value pairs that contribute very little.
- **Inference-side memory bottleneck**: Full attention repeatedly reads and writes Q, K, V, and intermediate values, so memory access to the KV-cache dominates the computation flow and prevents full utilization of compute resources.
Thus, a more effective approach is sparse attention: interacting each query with only the $w$ most relevant keys, reducing computation and memory access from $O(N^2)$ to $O(N \cdot w)$ where $w \ll N$. If the sparse pattern can adapt to the task, it has the potential to be both fast and accurate, addressing bottlenecks in both training and inference. For more details, please refer to the paper [Trainable Dynamic Mask Sparse Attention](https://arxiv.org/abs/2508.02124).
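As a toy illustration of this $O(N \cdot w)$ selection (a dense NumPy sketch of the semantics only, not the FSA kernel or its API; `topw_sparse_attention` is a hypothetical name), each query keeps only its $w$ highest-scoring keys:

```python
import numpy as np

def topw_sparse_attention(q, k, v, w):
    """Toy reference: each query attends only to its w highest-scoring keys.

    q, k, v: (N, d) arrays. Hypothetical helper, not the FSA API.
    """
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)                    # (N, N) attention logits
    # Threshold at the w-th largest score per query; drop everything below it.
    kth = np.partition(scores, -w, axis=-1)[:, -w:].min(axis=-1, keepdims=True)
    scores = np.where(scores >= kth, scores, -np.inf)
    p = np.exp(scores - scores.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)               # softmax over the kept keys
    return p @ v                                     # (N, d) output
```

This sketch still materializes the full $N \times N$ score matrix; the point of a trainable sparse kernel is to realize the same selection while touching only $O(N \cdot w)$ memory.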
## Key Features
### Supported Features
- Forward and backward passes with causal mask
- Arbitrary Q and KV sequence lengths
- Arbitrary number of heads and head dimensions up to 256
- Grouped Query Attention and Multi Query Attention
- Flexible Mask and Bias
- Skipping memory access and computation for masked regions
- Gradient computation for bias
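The mask-and-bias semantics above can be written down as a small pure-NumPy reference (an illustrative model of the numerics only, not the CUDA kernel; the function name and shapes are assumptions): a boolean mask removes positions from the softmax, an additive bias shifts the logits, and both broadcast to `(batch_size, num_heads, query_len, key_len)`.

```python
import numpy as np

def masked_biased_attention(q, k, v, mask, bias):
    """Illustrative reference for mask/bias attention, not the FSA kernel.

    q: (B, H, Nq, D); k, v: (B, H, Nk, D);
    mask (bool) and bias (float) broadcastable to (B, H, Nq, Nk).
    """
    d = q.shape[-1]
    logits = q @ k.swapaxes(-1, -2) / np.sqrt(d) + bias
    scores = np.where(mask, logits, -np.inf)        # masked positions drop out
    m = scores.max(axis=-1, keepdims=True)
    m = np.where(np.isfinite(m), m, 0.0)            # guard fully-masked rows
    p = np.where(mask, np.exp(scores - m), 0.0)     # masked entries get zero weight
    denom = p.sum(axis=-1, keepdims=True)
    p = np.divide(p, denom, out=np.zeros_like(p), where=denom > 0)
    return p @ v

# Broadcastable mask/bias shapes, as in the feature list above:
B, H, Nq, Nk, D = 2, 4, 8, 8, 16
rng = np.random.default_rng(0)
q = rng.standard_normal((B, H, Nq, D))
k = rng.standard_normal((B, H, Nk, D))
v = rng.standard_normal((B, H, Nk, D))
causal = np.tril(np.ones((Nq, Nk), dtype=bool))[None, None]  # (1, 1, Nq, Nk)
bias = rng.standard_normal((1, H, 1, Nk))                    # shared across batch and queries
out = masked_biased_attention(q, k, v, causal, bias)
```

In the real kernel the same mask additionally drives block-level skipping: a tile whose mask entries are all false is never loaded or computed, which this dense sketch does not model.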
### Features We Aim to Support
- Paged Attention
- TMA, WGMMA, and FP8 low-precision
- Sequence Parallelism
- Further performance improvements for skipping memory access and computation
## Performance
We present the expected speedup of Flash-Sparse-Attention (FSA) over standard PyTorch SDPA under mask and bias conditions.
### Forward Pass Performance

The following table shows the forward pass performance comparison between FSA and standard PyTorch SDPA on an NVIDIA A100-SXM4-80GB. Results are averaged over 3 runs after 2 warmup runs.
| Mode | Q len | K len | Window W | SDPA (ms) | FSA (ms) | Speedup |
### Backward Pass Performance
The following table shows the backward pass performance comparison between FSA and standard PyTorch SDPA on an NVIDIA A100-SXM4-80GB. Results are averaged over 3 runs after 2 warmup runs.
| Mode | Q len | K len | Window W | SDPA-BWD (ms) | FSA-BWD (ms) | Speedup |