@@ -8,6 +8,7 @@
from packaging.version import Version, parse
from setuptools import find_packages, setup
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension
+ from wheel.bdist_wheel import bdist_wheel as _bdist_wheel

# PEP0440 compatible formatted version, see:
# https://www.python.org/dev/peps/pep-0440/
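
A side note on the new import: recent releases of the wheel package deprecate the wheel.bdist_wheel module in favor of setuptools.command.bdist_wheel. A tolerant import, offered here only as a sketch and not part of this change, could look like:

    try:
        # newer setuptools ships its own bdist_wheel command
        from setuptools.command.bdist_wheel import bdist_wheel as _bdist_wheel
    except ImportError:
        # fall back to the wheel package, as the diff does
        from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
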
@@ -46,52 +47,117 @@
]

DEV_REQUIRES = INSTALL_REQUIRES + QUANLITY_REQUIRES

- MAIN_CUDA_VERSION = "12.1"

+ # ninja build does not work unless include_dirs are abs path
+ this_dir = os.path.dirname(os.path.abspath(__file__))

- def _is_cuda() -> bool:
-     return torch.version.cuda is not None
+ PACKAGE_NAME = "minference"

+ BASE_WHEEL_URL = (
+     "https://github.com/microsoft/MInference/releases/download/{tag_name}/{wheel_name}"
+ )

- def get_nvcc_cuda_version() -> Version:
-     """Get the CUDA version from nvcc.
+ # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
+ # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
+ FORCE_BUILD = os.getenv("MINFERENCE_FORCE_BUILD", "FALSE") == "TRUE"
+ SKIP_CUDA_BUILD = os.getenv("MINFERENCE_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
+ # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
+ FORCE_CXX11_ABI = os.getenv("MINFERENCE_FORCE_CXX11_ABI", "FALSE") == "TRUE"
+
+
+ def check_if_cuda_home_none(global_option: str) -> None:
+     if CUDA_HOME is not None:
+         return
+     # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
+     # in that case.
+     warnings.warn(
+         f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
+         "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
+         "only images whose names contain 'devel' will provide nvcc."
+     )

-     Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
-     """
-     assert CUDA_HOME is not None, "CUDA_HOME is not set"
-     nvcc_output = subprocess.check_output(
-         [CUDA_HOME + "/bin/nvcc", "-V"], universal_newlines=True
+
+ cmdclass = {}
+ ext_modules = []
+
+ if not SKIP_CUDA_BUILD:
+     print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
+     TORCH_MAJOR = int(torch.__version__.split(".")[0])
+     TORCH_MINOR = int(torch.__version__.split(".")[1])
+
+     # Check if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
+     # See https://github.com/pytorch/pytorch/pull/70650
+     generator_flag = []
+     torch_dir = torch.__path__[0]
+     if os.path.exists(
+         os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")
+     ):
+         generator_flag = ["-DOLD_GENERATOR_PATH"]
+
+     check_if_cuda_home_none("minference")
+
+     # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
+     # torch._C._GLIBCXX_USE_CXX11_ABI
+     # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
+     if FORCE_CXX11_ABI:
+         torch._C._GLIBCXX_USE_CXX11_ABI = True
+     ext_modules.append(
+         CUDAExtension(
+             name="minference.cuda",
+             sources=[
+                 os.path.join("csrc", "kernels.cpp"),
+                 os.path.join("csrc", "vertical_slash_index.cu"),
+             ],
+             extra_compile_args=["-std=c++17", "-O3"],
+         )
    )
-     output = nvcc_output.split()
-     release_idx = output.index("release") + 1
-     nvcc_cuda_version = parse(output[release_idx].split(",")[0])
-     return nvcc_cuda_version


def get_minference_version() -> str:
    version = VERSION["VERSION"]

-     if _is_cuda():
-         cuda_version = str(get_nvcc_cuda_version())
-         if cuda_version != MAIN_CUDA_VERSION:
-             cuda_version_str = cuda_version.replace(".", "")[:3]
-             version += f"+cu{cuda_version_str}"
+     local_version = os.environ.get("MINFERENCE_LOCAL_VERSION")
+     if local_version:
+         return f"{version}+{local_version}"
    else:
-         raise RuntimeError("Unknown runtime environment")
+         return str(version)

-     return version

+ class CachedWheelsCommand(_bdist_wheel):
+     """
+     The CachedWheelsCommand plugs into the default bdist wheel, which is run by pip when it cannot
+     find an existing wheel (which is currently the case for all flash attention installs). We use
+     the environment parameters to detect whether there is already a pre-built version of a compatible
+     wheel available and short-circuit the standard full build pipeline.
+     """
+
+     def run(self):
+         return super().run()
+
+
+ class NinjaBuildExtension(BuildExtension):
+     def __init__(self, *args, **kwargs) -> None:
+         # do not override env MAX_JOBS if it already exists
+         if not os.environ.get("MAX_JOBS"):
+             import psutil
+
+             # calculate the maximum allowed NUM_JOBS based on cores
+             max_num_jobs_cores = max(1, os.cpu_count() // 2)
+
+             # calculate the maximum allowed NUM_JOBS based on free memory
+             free_memory_gb = psutil.virtual_memory().available / (
+                 1024**3
+             )  # free memory in GB
+             max_num_jobs_memory = int(
+                 free_memory_gb / 9
+             )  # each JOB peak memory cost is ~8-9GB when threads = 4
+
+             # pick lower value of jobs based on cores vs memory metric to minimize oom and swap usage during compilation
+             max_jobs = max(1, min(max_num_jobs_cores, max_num_jobs_memory))
+             os.environ["MAX_JOBS"] = str(max_jobs)
+
+         super().__init__(*args, **kwargs)

- ext_modules = [
-     CUDAExtension(
-         name="minference.cuda",
-         sources=[
-             os.path.join("csrc", "kernels.cpp"),
-             os.path.join("csrc", "vertical_slash_index.cu"),
-         ],
-         extra_compile_args=["-std=c++17", "-O3"],
-     )
- ]

setup(
    name="minference",
@@ -110,7 +176,6 @@ def get_minference_version() -> str:
        "Programming Language :: Python :: 3",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
-     package_dir={"": "."},
    packages=find_packages(
        exclude=(
            "csrc",
@@ -136,5 +201,9 @@ def get_minference_version() -> str:
    python_requires=">=3.8.0",
    zip_safe=False,
    ext_modules=ext_modules,
-     cmdclass={"build_ext": BuildExtension},
+     cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": NinjaBuildExtension}
+     if ext_modules
+     else {
+         "bdist_wheel": CachedWheelsCommand,
+     },
)
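
One observation on the wiring above: CachedWheelsCommand.run() currently just defers to the stock bdist_wheel, so the short-circuit described in its docstring is not implemented yet. A minimal sketch of what that download-or-build step could look like, assuming a hypothetical guess_wheel_url() helper that fills BASE_WHEEL_URL's {tag_name} and {wheel_name} slots:

    import os
    import urllib.error
    import urllib.request
    from wheel.bdist_wheel import bdist_wheel as _bdist_wheel

    FORCE_BUILD = os.getenv("MINFERENCE_FORCE_BUILD", "FALSE") == "TRUE"

    def guess_wheel_url():
        # hypothetical helper: a real implementation would derive the release tag
        # and wheel name from the package version plus python/torch/CUDA versions
        wheel_name = "minference-0.0.0-cp310-cp310-linux_x86_64.whl"
        url = f"https://github.com/microsoft/MInference/releases/download/v0.0.0/{wheel_name}"
        return url, wheel_name

    class CachedWheelsCommand(_bdist_wheel):
        def run(self):
            if FORCE_BUILD:
                return super().run()
            wheel_url, wheel_name = guess_wheel_url()
            try:
                # drop the prebuilt wheel where bdist_wheel would have produced it
                os.makedirs(self.dist_dir, exist_ok=True)
                urllib.request.urlretrieve(wheel_url, os.path.join(self.dist_dir, wheel_name))
            except (urllib.error.HTTPError, urllib.error.URLError):
                super().run()  # no matching prebuilt wheel: fall back to compiling

With something like this in place, MINFERENCE_FORCE_BUILD=TRUE would still force local compilation, while a failed download falls back to the normal build.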