Skip to content

Commit

Permalink
Merge pull request #1 from alihassanijr/main
Browse files Browse the repository at this point in the history
CPU kernels, refactored code, new release
  • Loading branch information
alihassanijr authored Oct 16, 2022
2 parents bbb5fdd + db502ee commit 9d08906
Show file tree
Hide file tree
Showing 42 changed files with 2,250 additions and 728 deletions.
15 changes: 15 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
# Changelog

## [0.14.2] - 2022-10-15

### Added
- CPU support!
- CPP backend for CPU computation.
- CPU-only builds now supported.
- Note we only have naive kernels for CPU at the moment. Feel free to open a PR!

### Changed
- Refactored the CPP/CUDA backend.
- Unit tests for NA1D and NA2D
- Gradcheck tests in slow and fast mode
- Gradcheck tests for CPU backend
- Allclose tests between CPU and CUDA outputs and gradients

## [0.14.1] - 2022-10-08

### Added
Expand Down
10 changes: 8 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,19 @@ cd NATTEN
pip install -e .
```

#### Optional: unit tests
You can optionally run unit tests to verify building from source finished successfully:
```bash
python -m unittest discover -v -s ./tests
```


## Catalog
- [x] Neighborhood Attention 1D (CUDA)
- [x] Neighborhood Attention 2D (CUDA)
- [ ] Neighborhood Attention 3D (CUDA)
- [ ] Neighborhood Attention 1D (CPU)
- [ ] Neighborhood Attention 2D (CPU)
- [x] Neighborhood Attention 1D (CPU)
- [x] Neighborhood Attention 2D (CPU)
- [ ] Neighborhood Attention 3D (CPU)
- [x] Dilation support
- [x] Float16 support and utilization
Expand Down
155 changes: 155 additions & 0 deletions assets/README_pypi.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
![NATTENLogo](https://www.shi-labs.com/natten/assets/img/natten_light.png)

<a href="https://www.shi-labs.com/natten/"><img src="https://img.shields.io/badge/pip%20install%20natten-read%20more-%23C209C1" /></a>

*Neighborhood Attention Extension*

Bringing attention to a neighborhood near you!

NATTEN is an extension to PyTorch, which provides the first fast sliding window attention with efficient CUDA kernels.
It provides <a href="https://arxiv.org/abs/2204.07143">Neighborhood Attention</a> (local attention)
and <a href="https://arxiv.org/abs/2209.15001">Dilated Neighborhood Attention</a>
(sparse global attention, a.k.a. dilated local attention) as PyTorch modules for both 1D and 2D data.

## About NATTEN
Sliding window self attention mechanisms have been relatively overlooked, in part due to implementation difficulties.
For example, in a paper proposing one of the earliest examples of such methods,
[SASA](https://proceedings.neurips.cc/paper/2019/file/3416a75f4cea9109507cacd8e2f2aefc-Paper.pdf),
it was noted that
although such methods are theoretically efficient, they're relatively slow in practice, compared to convolutions,
which have been implemented in most well-known deep learning libraries.

That is why we started developing NATTEN, an extension to existing libraries with efficient implementations of sliding window
attention mechanisms, which will enable research in this direction including building powerful hierarchical vision
transformers.

For more information, we highly recommend reading our preprints [NAT](https://arxiv.org/abs/2204.07143) and
[DiNAT](https://arxiv.org/abs/2209.15001), and check out their [repository](https://github.com/SHI-Labs/Neighborhood-Attention-Transformer).

### How fast is NATTEN?
The latest version of NATTEN runs pretty fast on Ampere with the latest torch and CUDA versions.

![TimePlot](https://www.shi-labs.com/natten/assets/img/cudatime_light.png)
![MemPlot](https://www.shi-labs.com/natten/assets/img/cudamemory_light.png)


## Requirements
NATTEN supports PyTorch version 1.8 and later, and Python versions 3.7, 3.8, and 3.9.
However, we highly recommend using Python 3.8 and PyTorch 1.12.1 + CUDA 11.6 for the best performance.

**NOTE:** The current version of NATTEN comes with Linux-only wheels, and supports Pascal and above (`SM >= 60`, i.e. Tesla P100).
Make sure your GPU is supported by referring to
[this webpage](https://arnon.dk/matching-sm-architectures-arch-and-gencode-for-various-nvidia-cards/).
Future versions will extend support to older GPUs.

You may try and build from source on Windows, but do so at your own risk.
We also welcome contributions in all forms.

## Getting started

### Linux
Just refer to our website, [shi-labs.com/natten](https://www.shi-labs.com/natten/), select your PyTorch version and the CUDA
version it was compiled with, copy-paste the command and install in seconds!

For example, if you're on `torch==1.12.1+cu116`, you should install NATTEN using the following wheel:
```bash
pip3 install natten -f https://shi-labs.com/natten/wheels/cu116/torch1.12.1/index.html
```

More generally:
```bash
pip3 install natten -f https://shi-labs.com/natten/wheels/{cu_version}/torch{torch_version}/index.html
```

**NOTE:** If you do not specify a wheel URL, you will install a "placeholder" version of NATTEN, which is not usable.
We strongly recommend using our website, or building from source.

### Windows
NATTEN should support Windows devices with CUDA, but does not yet have Windows wheels.
You can try and build NATTEN from source (see below).

### Build from source
Once you've set up your Python environment and installed PyTorch with CUDA, simply clone and build:

```bash
pip install ninja # Recommended, not required
git clone https://github.com/SHI-Labs/NATTEN
cd NATTEN
pip install -e .
```


## Catalog
- [x] Neighborhood Attention 1D (CUDA)
- [x] Neighborhood Attention 2D (CUDA)
- [ ] Neighborhood Attention 3D (CUDA)
- [x] Neighborhood Attention 1D (CPU)
- [x] Neighborhood Attention 2D (CPU)
- [ ] Neighborhood Attention 3D (CPU)
- [x] Dilation support
- [x] Float16 support and utilization
- [ ] BFloat16 support
- [ ] Kepler and Maxwell (30<=SM<60) support
- [ ] Windows builds

## Usage
Simply import `NeighborhoodAttention1D` or `NeighborhoodAttention2D` from `natten`:
```python
from natten import NeighborhoodAttention1D
from natten import NeighborhoodAttention2D

na1d = NeighborhoodAttention1D(dim=128, kernel_size=7, dilation=2, num_heads=4).cuda()
na2d = NeighborhoodAttention2D(dim=128, kernel_size=7, dilation=2, num_heads=4).cuda()
```

### FLOPs
We recommend counting flops through [fvcore](https://github.com/facebookresearch/fvcore).

```shell
pip install fvcore
```

Once you have fvcore installed, you can directly use our dedicated FLOP counter:
```python
from natten.flops import get_flops

flops = get_flops(model, input)
```

Alternatively, if you are using fvcore's `FlopCountAnalysis` directly, be sure to add our op handles:
```python
from fvcore.nn import FlopCountAnalysis
from natten.flops import add_natten_handle

# ...

flop_ctr = FlopCountAnalysis(model, input)
flop_ctr = add_natten_handle(flop_ctr)

# ...
```

## License
NATTEN is released under the [MIT License](https://github.com/SHI-Labs/NATTEN/blob/main/LICENSE).

## Citation
```bibtex
@article{hassani2022neighborhood,
title = {Neighborhood Attention Transformer},
author = {Ali Hassani and Steven Walton and Jiachen Li and Shen Li and Humphrey Shi},
year = 2022,
url = {https://arxiv.org/abs/2204.07143},
eprint = {2204.07143},
archiveprefix = {arXiv},
primaryclass = {cs.CV}
}
@article{hassani2022dilated,
title = {Dilated Neighborhood Attention Transformer},
author = {Ali Hassani and Humphrey Shi},
year = 2022,
url = {https://arxiv.org/abs/2209.15001},
eprint = {2209.15001},
archiveprefix = {arXiv},
primaryclass = {cs.CV}
}
```
77 changes: 0 additions & 77 deletions dev/packaging/build_all_wheels.sh

This file was deleted.

17 changes: 10 additions & 7 deletions dev/packaging/build_all_wheels_parallel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ build_one() {
cu*)
container_name=manylinux-cuda${cu/cu/}
;;
cpu)
container_name=manylinux-cpu
;;
*)
echo "Unrecognized cu=$cu"
exit 1
Expand Down Expand Up @@ -48,17 +51,17 @@ EOF
if [[ -n "$1" ]] && [[ -n "$2" ]]; then
build_one "$1" "$2"
else
build_one cu116 1.12.1 & build_one cu113 1.12.1 & build_one cu102 1.12.1
build_one cu116 1.12.1 & build_one cu113 1.12.1 & build_one cu102 1.12.1 & build_one cpu 1.12.1

build_one cu116 1.12 & build_one cu113 1.12 & build_one cu102 1.12
build_one cu116 1.12 & build_one cu113 1.12 & build_one cu102 1.12 & build_one cpu 1.12

build_one cu115 1.11 & build_one cu113 1.11 & build_one cu102 1.11
build_one cu115 1.11 & build_one cu113 1.11 & build_one cu102 1.11 & build_one cpu 1.11

build_one cu113 1.10.1 & build_one cu111 1.10.1 & build_one cu102 1.10.1
build_one cu113 1.10.1 & build_one cu111 1.10.1 & build_one cu102 1.10.1 & build_one cpu 1.10.1

build_one cu113 1.10 & build_one cu111 1.10 & build_one cu102 1.10
build_one cu113 1.10 & build_one cu111 1.10 & build_one cu102 1.10 & build_one cpu 1.10

build_one cu111 1.9 & build_one cu102 1.9
build_one cu111 1.9 & build_one cu102 1.9 & build_one cpu 1.9

build_one cu111 1.8 & build_one cu102 1.8 & build_one cu101 1.8
build_one cu111 1.8 & build_one cu102 1.8 & build_one cu101 1.8 & build_one cpu 1.8
fi
2 changes: 1 addition & 1 deletion dev/packaging/build_default_wheel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
}

pytorch_ver="1.8"
container_name=manylinux-cuda101
container_name=manylinux-cpu
cu="cpu"
py_versions=(3.7 3.8 3.9)

Expand Down
5 changes: 3 additions & 2 deletions dev/packaging/build_wheel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@ echo "PYTORCH_VERSION: $PYTORCH_VERSION" # e.g. 1.4
setup_cuda
setup_wheel_python

yum install ninja-build -y
ln -sv /usr/bin/ninja-build /usr/bin/ninja || true
#yum install ninja-build -y
#ln -sv /usr/bin/ninja-build /usr/bin/ninja || true

pip_install pip numpy -U
pip_install ninja
pip_install "torch==$PYTORCH_VERSION" \
-f https://download.pytorch.org/whl/"$CU_VERSION"/torch_stable.html

Expand Down
2 changes: 1 addition & 1 deletion dev/packaging/gen_wheel_index.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export LC_ALL=C # reproducible sort
index=$root/index.html

cd "$root"
for cu in cu101 cu102 cu111 cu113 cu115 cu116; do
for cu in cpu cu101 cu102 cu111 cu113 cu115 cu116; do
mkdir -p "$root/$cu"
cd "$root/$cu"
echo "Creating $PWD/index.html ..."
Expand Down
1 change: 0 additions & 1 deletion dev/packaging/pkg_helpers.bash
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ setup_cuda() {
cpu)
unset FORCE_CUDA
export CUDA_VISIBLE_DEVICES=
echo "WARNING: NATTEN does not have a CPU build yet. This is a placeholder."
;;
*)
echo "Unrecognized CU_VERSION=$CU_VERSION"
Expand Down
8 changes: 3 additions & 5 deletions natten/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
from .nattencuda1d import NeighborhoodAttention1D
from .nattencuda2d import NeighborhoodAttention2D
from .natten1d import NeighborhoodAttention1D
from .natten2d import NeighborhoodAttention2D, NeighborhoodAttention

from .nattencuda import NeighborhoodAttention

__version__ = "0.14.1"
__version__ = "0.14.2"
2 changes: 1 addition & 1 deletion natten/flops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
Neighborhood Attention FLOP counter
Because we're using a custom CUDA kernel, FVCore won't recognize it and count its flops, so we have
Because we're using a custom CPP backend, FVCore won't recognize it and count flops, so we have
to manually define flop counters for each extension.
This source code is licensed under the license found in the
Expand Down
2 changes: 1 addition & 1 deletion natten/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
except ImportError:
raise ImportError(f"Failed to import NATTEN's CPP backend. " + \
f"This could be due to an invalid/incomplete install. " + \
f"Please uninstall NATTEN (pip uninstall natten) and re-install with the correct torch and cuda build: " + \
f"Please uninstall NATTEN (pip uninstall natten) and re-install with the correct torch build: " + \
f"natten.shi-labs.com."
)

Expand Down
2 changes: 1 addition & 1 deletion natten/nattencuda1d.py → natten/natten1d.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Neighborhood Attention 1D PyTorch Module (CUDA only)
Neighborhood Attention 1D PyTorch Module
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
Expand Down
Loading

0 comments on commit 9d08906

Please sign in to comment.