Skip to content

Commit

Permalink
PyTorch 2.0 Support (#21)
Browse files Browse the repository at this point in the history
* PyTorch 2.0 wheels!

* Almost forgot
  • Loading branch information
alihassanijr authored Mar 16, 2023
1 parent ad39be7 commit 3f43acd
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 18 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.PHONY: clean uninstall install test style quality
.PHONY: clean uninstall install install-cutlass test style quality

check_dirs := src tests

Expand Down
7 changes: 6 additions & 1 deletion dev/packaging/build_all_wheels_parallel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ build_one() {
echo "Launching container $container_name ..."
container_id="$container_name"_"$cu"_"$pytorch_ver"

if [ $cp310 -eq 2 ]; then
if [ $cp310 -eq 3 ]; then
py_versions=(3.8 3.9 3.10 3.11)
elif [ $cp310 -eq 2 ]; then
py_versions=(3.7 3.8 3.9 3.10 3.11)
elif [ $cp310 -eq 1 ]; then
py_versions=(3.7 3.8 3.9 3.10)
Expand Down Expand Up @@ -57,6 +59,9 @@ EOF
if [[ -n "$1" ]] && [[ -n "$2" ]]; then
build_one "$1" "$2"
else
# 2.0 and newer -- build 3.8 <= python <= 3.11 wheels
build_one cu118 2.0.0 3 & build_one cu117 2.0.0 3 & build_one cpu 2.0.0 3

# 1.13 and newer -- build python 3.11 wheels
build_one cu117 1.13 2 & build_one cu116 1.13 2 & build_one cpu 1.13 2

Expand Down
2 changes: 1 addition & 1 deletion dev/packaging/gen_wheel_index.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ export LC_ALL=C # reproducible sort
index=$root/index.html

cd "$root"
for cu in cpu cu101 cu102 cu111 cu113 cu115 cu116 cu117; do
for cu in cpu cu101 cu102 cu111 cu113 cu115 cu116 cu117 cu118; do
mkdir -p "$root/$cu"
cd "$root/$cu"
echo "Creating $PWD/index.html ..."
Expand Down
4 changes: 4 additions & 0 deletions dev/packaging/pkg_helpers.bash
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ setup_cuda() {
# and https://github.com/pytorch/vision/blob/main/packaging/pkg_helpers.bash for reference.
export FORCE_CUDA=1
case "$CU_VERSION" in
cu118)
export CUDA_HOME=/usr/local/cuda-11.8/
export TORCH_CUDA_ARCH_LIST="6.0;6.1+PTX;7.0;7.5+PTX;8.0;8.6+PTX;8.9+PTX;9.0+PTX"
;;
cu117)
export CUDA_HOME=/usr/local/cuda-11.7/
export TORCH_CUDA_ARCH_LIST="6.0;6.1+PTX;7.0;7.5+PTX;8.0;8.6+PTX"
Expand Down
41 changes: 27 additions & 14 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
#!/usr/bin/env python
"""
Neighborhood Attention Extension (NATTEN)
Setup file
Heavily borrowed from detectron2 setup:
github.com/facebookresearch/detectron2
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
#################################################################################################
# Copyright (c) 2023 Ali Hassani.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
#################################################################################################

import warnings
import glob
Expand All @@ -20,8 +31,8 @@
from typing import List
import torch
from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension

from pathlib import Path

this_directory = Path(__file__).parent
try:
long_description = (this_directory / "assets/README_pypi.md").read_text()
Expand Down Expand Up @@ -62,19 +73,21 @@ def get_extension():
main_source = path.join(extensions_dir, "natten.cpp")
sources_cpu = glob.glob(path.join(extensions_dir, "cpu", "*.cpp"))
source_cuda = glob.glob(path.join(extensions_dir, "cuda", "*.cu"))
sources = [main_source] + sources_cpu
sources_base = [main_source] + sources_cpu
sources = sources_base.copy()

from torch.utils.cpp_extension import ROCM_HOME

is_rocm_pytorch = (
True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False
)
assert not is_rocm_pytorch, "Unfortunately NATTEN does not support ROCM."
assert not is_rocm_pytorch, "NATTEN does not support ROCM."

extension = CppExtension
extra_compile_args = {"cxx": ["-O3"]}
define_macros = []
if TORCH_113:
# Torch 1.13 and above have a new dispatcher template
define_macros += [("TORCH_113", 1)]
if AVX_INT:
define_macros += [("AVX_INT", 1)]
Expand Down
2 changes: 1 addition & 1 deletion src/natten/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
from .natten1d import NeighborhoodAttention1D
from .natten2d import NeighborhoodAttention2D

__version__ = "0.14.5.dev0"
__version__ = "0.14.5"

0 comments on commit 3f43acd

Please sign in to comment.