Commit 938aec9

Update README and README_zh to reflect the new Flash-Sparse-Attention branding and features
1 parent b077f0c commit 938aec9

2 files changed: +118 -282 lines changed

README.md

Lines changed: 55 additions & 137 deletions
@@ -10,38 +10,54 @@
1010
</div>
1111

1212

13-
![Flash-DMA Banner](assets/flash_dmattn_banner.png)
13+
![Flash-Sparse-Attention Banner](assets/flash_sparse_attention_banner.png)
1414

15-
Flash-DMA is a high-performance attention implementation that integrates Flash Attention's memory efficiency with Dynamic Mask Attention's sparse computation capabilities for processing extremely long sequences in transformer models.
15+
Flash-Sparse-Attention is a high-performance trainable sparse attention implementation that integrates Flash Attention's memory efficiency with Dynamic Mask Attention's sparse computation capabilities for processing extremely long sequences in transformer models.
16+
17+
18+
## Why Flash-Sparse-Attention
19+
20+
In large-scale Transformer training and inference, the dominant bottlenecks diverge:
21+
22+
- **Training-side compute bottleneck**: The computational complexity of full attention grows quadratically with sequence length, and backpropagation requires repeating computations of the same order, leading to massive compute consumption on key-value pairs that contribute very little.
23+
- **Inference-side memory bottleneck**: Full attention requires repeated reading and writing of Q, K, V, and intermediate variables, making memory access to the KV-cache the dominant factor in the computation flow, hindering full utilization of compute resources.
24+
25+
Thus, a more effective approach is sparse attention: interacting each query with only the $w$ most relevant keys, reducing computation and memory access from $O(N^2)$ to $O(N \cdot w)$ where $w \ll N$. If the sparse pattern can adapt to the task, it has the potential to be both fast and accurate, addressing bottlenecks in both training and inference. For more details, please refer to the paper [Trainable Dynamic Mask Sparse Attention](https://arxiv.org/abs/2508.02124).
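
To make the $O(N \cdot w)$ claim concrete, here is a minimal pure-Python sketch of one query attending to only its $w$ highest-scoring keys. It is an illustration only, not the FSA kernel: the function name `topw_sparse_attention`, the use of a per-key bias as the relevance score, and adding that bias to the scaled dot product are assumptions made for this example.

```python
import math


def topw_sparse_attention(q, keys, values, bias, w):
    """Attend one query to only the w keys with the largest bias.

    q:      list of d floats (a single query vector)
    keys:   list of N key vectors, each a list of d floats
    values: list of N value vectors, each a list of d floats
    bias:   list of N floats, one relevance score per key (assumed input)
    Returns (output vector, number of query-key dot products computed).
    """
    d = len(q)
    n = len(keys)
    w = min(w, n)

    # Keep the indices of the w largest bias values; everything else is masked.
    kept = sorted(range(n), key=lambda j: bias[j], reverse=True)[:w]

    # Scores are computed for the kept keys only: w dot products instead of N.
    scores = [
        sum(qi * ki for qi, ki in zip(q, keys[j])) / math.sqrt(d) + bias[j]
        for j in kept
    ]

    # Numerically stable softmax over the kept keys.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)

    # Weighted sum of the kept value vectors.
    out = [0.0] * d
    for p, j in zip(exps, kept):
        for c in range(d):
            out[c] += (p / z) * values[j][c]
    return out, len(kept)


q = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
bias = [0.0, 5.0, -1.0, 4.0]
out, n_dots = topw_sparse_attention(q, keys, values, bias, w=2)
```

With $N$ queries, the per-query cost of $w$ dot products gives $N \cdot w$ in total, compared with $N^2$ for full attention.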
1626

1727

1828
## Key Features
1929

20-
### 🎯 Core Kernel Advantages
21-
- **Mask & Bias Support**: Native support for `({1|batch_size}, {1|num_kv_heads|num_heads}, {1|query_len}, {1|key_len})` shaped attention mask and attention bias tensors
22-
- **Intelligent Computation Skipping**: Block-level automatic skipping mechanism based on masks, completely bypassing computation and memory access for zero-mask blocks
23-
- **Complete Gradient Support**: Built-in full gradient computation path for attention bias, supporting end-to-end training
30+
### Supported Features
31+
32+
- Forward and backward passes with causal mask
33+
- Arbitrary Q and KV sequence lengths
34+
- Arbitrary number of heads and head dimensions up to 256
35+
- Grouped Query Attention and Multi Query Attention
36+
- Flexible Mask and Bias
37+
- Skipping memory access and computation for masked regions
38+
- Gradient computation for bias
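
The "skipping memory access and computation for masked regions" item can be pictured as a tile-level decision: a tile of the score matrix is processed only if at least one of its mask entries is set. The sketch below shows that OR-reduction in plain Python. The function name `active_blocks` and the tile sizes are hypothetical, and the real kernels make this decision on the GPU rather than over Python lists.

```python
def active_blocks(mask, block_m, block_n):
    """Return (row_block, col_block) indices of tiles with any unmasked entry.

    mask: list of rows, shape (query_len, key_len); truthy means "attend".
    block_m, block_n: tile height and width (hypothetical values).
    """
    q_len = len(mask)
    k_len = len(mask[0]) if q_len else 0
    kept = []
    for i in range(0, q_len, block_m):
        for j in range(0, k_len, block_n):
            # OR-reduction over the tile; edge tiles are clipped to the matrix.
            tile_has_work = any(
                mask[r][c]
                for r in range(i, min(i + block_m, q_len))
                for c in range(j, min(j + block_n, k_len))
            )
            if tile_has_work:
                kept.append((i // block_m, j // block_n))
    return kept


mask = [
    [1, 1, 0, 0],
    [1, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 1],
]
tiles = active_blocks(mask, block_m=2, block_n=2)
```

In this example two of the four tiles are all zeros, so a kernel following this rule would load K/V and compute scores for only the remaining two.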
39+
40+
### Features We Aim to Support
2441

25-
### 🚀 Performance & Efficiency
26-
- **Dynamic Sparse Attention**: Dynamically selects the most relevant keys for each query, reducing computational complexity from $O(N^2)$ to $O(N \cdot w)$ where $w \ll N$, supporting trainable sparse structures
27-
- **Memory Efficiency**: Maintains Flash Attention's $O(N)$ memory complexity without instantiating the full attention matrix
28-
- **CUDA Deep Optimization**: Custom CUDA kernels with shared memory aliasing, pipelined prefetching, and block skipping for high throughput and low memory access overhead
29-
- **Extremely Long Context Support**: Handles 128K+ token sequences efficiently through dynamic mask windowing while preserving accuracy
42+
- Paged Attention
43+
- TMA, WGMMA, and FP8 low-precision
44+
- Sequence Parallelism
45+
- Further performance improvements for skipping memory access and computation
3046

3147

3248
## Performance
3349

34-
We present the expected speedup of Flash-DMA over standard PyTorch SDPA under mask and bias conditions.
50+
We present the expected speedup of FSA over standard PyTorch SDPA under mask and bias conditions.
3551

36-
![Flash-DMA Performance Overview](assets/performance_overview.png)
52+
![FSA Performance Overview](assets/performance_overview.png)
3753

3854
---
3955

4056
### Forward Pass Performance
4157

42-
The following table shows the forward pass performance comparison between Flash-DMA and standard PyTorch SDPA on an NVIDIA A100-SXM4-80GB. Results are averaged over 3 runs after 2 warmup runs.
58+
The following table shows the forward pass performance comparison between FSA and standard PyTorch SDPA on an NVIDIA A100-SXM4-80GB. Results are averaged over 3 runs after 2 warmup runs.
4359

44-
| Mode | Q len | K len | Window W | SDPA (ms) | FDMA (ms) | Speedup |
60+
| Mode | Q len | K len | Window W | SDPA (ms) | FSA (ms) | Speedup |
4561
|--------|-------|--------|----------|-----------|-----------|---------|
4662
| Train | 256 | 256 | 1024 | 0.29 | 0.19 | 1.58x |
4763
| Train | 512 | 512 | 1024 | 0.35 | 0.19 | 1.86x |
@@ -91,9 +107,9 @@ The following table shows the forward pass performance comparison between Flash-
91107

92108
### Backward Pass Performance
93109

94-
The following table shows the backward pass performance comparison between Flash-DMA and standard PyTorch SDPA on an NVIDIA A100-SXM4-80GB. Results are averaged over 3 runs after 2 warmup runs.
110+
The following table shows the backward pass performance comparison between FSA and standard PyTorch SDPA on an NVIDIA A100-SXM4-80GB. Results are averaged over 3 runs after 2 warmup runs.
95111

96-
| Mode | Q len | K len | Window W | SDPA-BWD (ms) | FDMA-BWD (ms) | Speedup |
112+
| Mode | Q len | K len | Window W | SDPA-BWD (ms) | FSA-BWD (ms) | Speedup |
97113
|-------|-------|--------|----------|---------------|---------------|---------|
98114
| Train | 256 | 256 | 1024 | 0.42 | 0.62 | 0.7x |
99115
| Train | 512 | 512 | 1024 | 0.56 | 0.60 | 0.9x |
@@ -131,17 +147,17 @@ The following table shows the backward pass performance comparison between Flash
131147

132148
### Install
133149

134-
You can install Flash-DMA via pre-compiled wheels:
150+
You can install FSA via pre-compiled wheels:
135151

136152
```bash
137-
pip install flash-dmattn --no-build-isolation
153+
pip install flash_sparse_attn --no-build-isolation
138154
```
139155

140156
Alternatively, you can compile and install from source:
141157

142158
```bash
143-
git clone https://github.com/SmallDoges/flash-dmattn.git
144-
cd flash-dmattn
159+
git clone https://github.com/SmallDoges/flash_sparse_attn.git
160+
cd flash_sparse_attn
145161
pip install . --no-build-isolation
146162
```
147163

@@ -152,8 +168,8 @@ pip install . --no-build-isolation
152168

153169
```python
154170
import torch
155-
from flash_dmattn import flash_dmattn_func_auto
156-
from flash_dmattn.utils.mask import create_mask
171+
from flash_sparse_attn import flash_sparse_attn_func_auto
172+
from flash_sparse_attn.utils.mask import create_mask
157173
import math
158174

159175
# Setup
@@ -169,7 +185,7 @@ key = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dt
169185
value = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype)
170186

171187
# Create bias for sparse attention
172-
attn_bias = torch.randn(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
188+
attn_bias = torch.randn(batch_size, num_kv_heads, 1, seq_len, device=device, dtype=dtype)
173189

174190
# Generate dynamic mask based on bias
175191
if seq_len > window_size:
@@ -183,11 +199,11 @@ if seq_len > window_size:
183199
min_dtype=min_dtype,
184200
)
185201

186-
# Select FDMA kernel
187-
flash_dmattn_func = flash_dmattn_func_auto(backend="cuda")
202+
# Select FSA kernel
203+
flash_sparse_attn_func = flash_sparse_attn_func_auto(backend="cuda")
188204

189-
# Run Flash Dynamic Mask Attention
190-
output = flash_dmattn_func(
205+
# Run Flash-Sparse-Attention
206+
output = flash_sparse_attn_func(
191207
query=query,
192208
key=key,
193209
value=value,
@@ -210,7 +226,7 @@ value.requires_grad_(True)
210226
attn_bias.requires_grad_(True)
211227

212228
# Forward pass
213-
output = flash_dmattn_func(
229+
output = flash_sparse_attn_func(
214230
query=query, key=key, value=value,
215231
attn_mask=attn_mask,
216232
attn_bias=attn_bias,
@@ -229,67 +245,9 @@ print(f"Bias gradient shape: {attn_bias.grad.shape}")
229245
```
230246

231247

232-
## How It Works
233-
234-
Flash-DMA integrates the efficient memory access patterns of Flash Attention with the sparse computation capabilities of dynamic mask attention to achieve an efficient attention mechanism.
235-
236-
### Core Technology Integration
237-
238-
- **🎯 Native Mask & Bias Support**: Kernels directly process `({1|batch_size}, {1|num_kv_heads|num_heads}, {1|query_len}, {1|key_len})` shaped tensors
239-
- **⚡ Block-level Intelligent Skipping**: Unified OR-reduction skipping logic based on masks, completely avoiding computation and memory access for zero blocks
240-
- **🔄 Complete Gradient Chain**: Built-in attention bias gradient computation supporting end-to-end differentiable training
241-
242-
### Key Optimization Strategies
243-
244-
1. **Unified Skip Logic**: Forward and backward passes use the same block-level skip decisions
245-
2. **Memory Access Optimization**: K/V data loaded only when `OR(mask_block) == true`
246-
3. **Gradient Path Completeness**: dbias gradient computation fully fused in backward kernels
247-
4. **Shared Memory Reuse**: sMask ↔ sP, sBias ↔ sdS intelligent aliasing
248-
249-
250-
## Documentation
251-
252-
📚 **Complete documentation is available in the [docs](docs/) directory:**
253-
254-
- **[API Reference](docs/api_reference.md)** - Complete function documentation and usage examples
255-
- **[Integration Guide](docs/integration.md)** - Detailed technical documentation of the Flash Attention integration
256-
257-
258-
## Building from Source
259-
260-
### Development Setup
261-
262-
```bash
263-
# Clone with submodules
264-
git clone https://github.com/SmallDoges/flash-dmattn.git
265-
cd flash-dmattn
266-
267-
# Build in development mode
268-
pip install -e .
269-
270-
# Run tests to verify installation
271-
python -c "import flash_dma_cuda; print('✅ Flash DMA CUDA extension imported successfully')"
272-
```
273-
274-
### Build Requirements
275-
276-
- CUDA Toolkit 11.8+
277-
- CUTLASS library
278-
- PyTorch with CUDA support
279-
280-
### Supported Architectures
281-
282-
- **SM 8.0**
283-
- **SM 9.0**
284-
- **SM 10.0**
285-
- **SM 12.0**
286-
287-
**Note**: Flash Dynamic Mask Attention requires CUDA compute capability 8.0+ for optimal performance. Earlier architectures are not supported.
288-
289-
290248
## Benchmarking
291249

292-
Flash-DMA provides comprehensive benchmarking tools to evaluate performance across different configurations:
250+
FSA provides comprehensive benchmarking tools to evaluate performance across different configurations:
293251

294252
### Forward Pass Equivalence
295253
```bash
@@ -301,7 +259,7 @@ Validates numerical consistency between Python reference and CUDA implementation
301259
```bash
302260
python benchmarks/forward_performance.py
303261
```
304-
Compares Flash-DMA against standard SDPA across various sequence lengths and batch sizes.
262+
Compares FSA against standard SDPA across various sequence lengths and batch sizes.
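
For readers who want to reproduce the "averaged over 3 runs after 2 warmup runs" protocol on their own workloads, a minimal timing helper might look like the sketch below. It is not taken from `benchmarks/forward_performance.py`; the name `benchmark` and its `sync` parameter are assumptions for illustration. On a GPU, passing `sync=torch.cuda.synchronize` is needed because CUDA kernels launch asynchronously.

```python
import statistics
import time


def benchmark(fn, warmup=2, runs=3, sync=None):
    """Return the mean wall-clock time of fn() in milliseconds.

    warmup: untimed calls made first (caches, lazy initialization).
    runs:   timed calls that are averaged.
    sync:   optional callable that blocks until pending work finishes,
            e.g. torch.cuda.synchronize for GPU code (assumed usage).
    """
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()

    times_ms = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for asynchronous work before stopping the clock
        times_ms.append((time.perf_counter() - start) * 1e3)
    return statistics.mean(times_ms)


calls = []
mean_ms = benchmark(lambda: calls.append(1))
```

The speedup column in the tables above then corresponds to the baseline time divided by the FSA time for the same configuration.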
305263

306264
### Backward Pass Equivalence
307265
```bash
@@ -313,7 +271,7 @@ Validates numerical consistency between Python reference and CUDA implementation
313271
```bash
314272
python benchmarks/backward_performance.py
315273
```
316-
Compares Flash-DMA against standard SDPA across various sequence lengths and batch sizes.
274+
Compares FSA against standard SDPA across various sequence lengths and batch sizes.
317275

318276
### Gradient Computation
319277
```bash
@@ -322,61 +280,21 @@ python benchmarks/grad_equivalence.py
322280
Tests backward pass implementation and gradient equivalence.
323281

324282

325-
## Troubleshooting
326-
327-
### Common Issues
328-
329-
**Compilation Errors**
330-
```bash
331-
# Ensure CUDA_HOME is set correctly
332-
echo $CUDA_HOME # Linux/Mac
333-
echo $env:CUDA_HOME # Windows PowerShell
334-
335-
# Check CUDA toolkit version
336-
nvcc --version
337-
338-
# Verify PyTorch CUDA support
339-
python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')"
340-
```
341-
342-
**Import Errors**
343-
```python
344-
# Test basic import
345-
try:
346-
from flash_dmattn import flash_dmattn_func, get_available_backends
347-
print("✅ Flash Dynamic Mask Attention imported successfully")
348-
print(f"Available backends: {get_available_backends()}")
349-
except ImportError as e:
350-
print(f"❌ Import failed: {e}")
351-
print("Please ensure the package is properly installed with: pip install -e .")
352-
```
353-
354-
**Performance Issues**
355-
```python
356-
# Monitor GPU memory usage
357-
from flash_dmattn import flash_dmattn_func
358-
359-
def print_memory_stats():
360-
if torch.cuda.is_available():
361-
print(f"GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
283+
## Documentation
362284

363-
print_memory_stats()
364-
output = flash_dmattn_func(q=query, k=key, v=value, is_causal=True)
365-
print_memory_stats()
285+
📚 **Complete documentation is available in the [docs](docs/) directory:**
366286

367-
# Clear cache if needed
368-
torch.cuda.empty_cache()
369-
```
287+
- **[API Reference](docs/api_reference.md)** - Complete function documentation and usage examples
370288

371289

372290
## Contributing
373291

374-
We welcome contributions from the community! Flash-DMA is an open-source project and we value all types of contributions.
292+
We welcome contributions from the community! FSA is an open-source project and we value all types of contributions.
375293

376294
### How to Contribute
377295

378-
- **Report bugs**: Found a bug? Please [open an issue](https://github.com/SmallDoges/flash-dmattn/issues/new/choose)
379-
- **Request features**: Have an idea for improvement? [Let us know](https://github.com/SmallDoges/flash-dmattn/issues/new/choose)
296+
- **Report bugs**: Found a bug? Please [open an issue](https://github.com/SmallDoges/flash_sparse_attn/issues/new/choose)
297+
- **Request features**: Have an idea for improvement? [Let us know](https://github.com/SmallDoges/flash_sparse_attn/issues/new/choose)
380298
- **Submit code**: Ready to contribute code? Check our [Contributing Guide](CONTRIBUTING.md)
381299
- **Improve docs**: Help us make the documentation better
382300

@@ -401,7 +319,7 @@ This project is licensed under the BSD 3-Clause License. See [LICENSE](LICENSE)
401319

402320
## Citation
403321

404-
If you use Flash-DMA in your research, please cite:
322+
If you use FSA in your research, please cite:
405323

406324
```bibtex
407325
@misc{shi2025trainabledynamicmasksparse,
