build_cffi.py (forked from NVIDIA/apex)
# This file contains the cffi-extension call to build the custom
# kernel used by amp.
# For mysterious reasons, it needs to live at the top-level directory.
# TODO: remove this when we move to cpp-extension.
import os
import torch
# NOTE: torch.utils.ffi only exists in older PyTorch releases; it was removed
# in favor of cpp extensions (see the TODO above).
from torch.utils.ffi import create_extension
assert torch.cuda.is_available()
abs_path = os.path.dirname(os.path.realpath(__file__))
sources = ['apex/amp/src/scale_cuda.c']
headers = ['apex/amp/src/scale_cuda.h']
defines = [('WITH_CUDA', None)]
with_cuda = True
extra_objects = [os.path.join(abs_path, 'build/scale_kernel.o')]
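# The link step expects a precompiled CUDA object at build/scale_kernel.o.
# A minimal sketch of how such an object might be produced (the source path
# apex/amp/src/scale_kernel.cu and the compiler flags are assumptions, not
# taken from this repository's actual build steps):
#
#     mkdir -p build
#     nvcc -c apex/amp/src/scale_kernel.cu -o build/scale_kernel.o \
#         -Xcompiler -fPIC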
# When running `python build_cffi.py` directly, set package=False. But
# if it's used with `cffi_modules` in setup.py, then set package=True.
package = (__name__ != '__main__')
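# A minimal sketch (hypothetical, not this repository's actual setup.py) of
# the cffi_modules route mentioned above; the entry string follows cffi's
# "filename:variable" convention and points at the `extension` object
# defined below:
#
#     setup(
#         name='apex',
#         ...,
#         cffi_modules=['build_cffi.py:extension'],
#     )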
extension = create_extension(
    'apex.amp._C.scale_lib',
    package=package,
    headers=headers,
    sources=sources,
    define_macros=defines,
    relative_to=__file__,
    with_cuda=with_cuda,
    extra_objects=extra_objects
)
if __name__ == '__main__':
    extension.build()
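# Typical usage (a sketch; the exported symbols depend on what scale_cuda.h
# declares and are not shown here):
#
#     python build_cffi.py          # standalone build, package=False
#
# After a successful build, the compiled module should be importable under the
# dotted name passed to create_extension above:
#
#     from apex.amp._C import scale_lib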