Skip to content

Commit 981a39a

Browse files
authored
feat(build): Implement Python wrapper build using setuptools. (#28)
* add setup.py. * fix setup.py.
1 parent 87b1f6d commit 981a39a

File tree

12 files changed

+255
-18
lines changed

12 files changed

+255
-18
lines changed

.gitignore

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,28 @@
44
*.ptx
55
*.cubin
66
*.fatbin
7+
8+
# Byte-compiled / optimized / DLL files
9+
__pycache__/
10+
*.py[cod]
11+
*$py.class
12+
13+
# Distribution / packaging
14+
.Python
15+
build/
16+
develop-eggs/
17+
dist/
18+
downloads/
19+
eggs/
20+
.eggs/
21+
lib/
22+
lib64/
23+
parts/
24+
sdist/
25+
var/
26+
wheels/
27+
share/python-wheels/
28+
*.egg-info/
29+
.installed.cfg
30+
*.egg
31+
MANIFEST

README.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,26 @@ cd TileFusion && git submodule update --init --recursive
8888

8989
TileFusion requires a C++20 host compiler, CUDA 12.0 or later, and GCC version 10.0 or higher to support C++20 features.
9090

91+
### Build from Source
92+
93+
#### Using Makefile
94+
To build the project using the provided `Makefile`, simply run:
95+
```bash
96+
make
97+
```
98+
99+
#### Building the Python Wrapper
100+
101+
1. Build the wheel:
102+
```bash
103+
python3 setup.py build bdist_wheel
104+
```
105+
106+
2. Clean the build:
107+
```bash
108+
python3 setup.py clean
109+
```
110+
91111
### Unit Test
92112

93113
- **Run a single unit test**: `make unit_test UNIT_TEST=test_scatter_nd.py`

cmake/generic.cmake

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ set(CMAKE_BUILD_TYPE Release)
77

88
set(CMAKE_CXX_STANDARD
99
20
10-
CACHE STRING "The C++ standard whoese features are requested." FORCE)
10+
CACHE STRING "The C++ standard whose features are requested." FORCE)
1111
set(CMAKE_CXX_STANDARD_REQUIRED ON)
1212

1313
set(CMAKE_CUDA_STANDARD
@@ -48,6 +48,12 @@ set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} -std=c++20)
4848
set(CUDA_NVCC_FLAGS_DEBUG ${CUDA_NVCC_FLAGS_DEBUG} -std=c++20 -O0)
4949
set(CUDA_NVCC_FLAGS_RELEASE ${CUDA_NVCC_FLAGS_RELEASE} -std=c++20 -O3)
5050

51+
if(${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "11")
52+
add_definitions("-DENABLE_BF16")
53+
message(STATUS "CUDA_VERSION ${CUDA_VERSION_MAJOR}.${CUDA_VERSION_MINOR} "
54+
"is greater or equal than 11.0, enable -DENABLE_BF16 flag.")
55+
endif()
56+
5157
message(STATUS "tilefusion: CUDA detected: " ${CUDA_VERSION})
5258
message(STATUS "tilefusion: CUDA nvcc is: " ${CUDA_NVCC_EXECUTABLE})
5359
message(STATUS "tilefusion: CUDA toolkit directory: " ${CUDA_TOOLKIT_ROOT_DIR})

pyproject.toml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,26 @@ classifiers = [
1515
"Operating System :: OS Independent",
1616
"Topic :: Software Development :: Libraries",
1717
]
18+
# NOTE: setuptools's `install_requires` can overwritten in
19+
# `pyproject.toml`'s `dependencies` field.
20+
# Make sure to keep this field in sync with what is in `requirements.txt`.
21+
dependencies = [
22+
"torch",
23+
]
1824

1925
[project.urls]
2026
Homepage = "https://github.com/microsoft/TileFusion"
2127
Issues = "https://github.com/microsoft/TileFusion/issues"
2228

29+
[build-system]
30+
requires = [
31+
"cmake",
32+
"packaging",
33+
"setuptools>=49.4.0",
34+
"wheel",
35+
]
36+
build-backend = "setuptools.build_meta"
37+
2338
[tool.ruff]
2439
line-length = 80
2540
exclude = [

pytilefusion/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55

66
import torch
77

8-
torch.ops.load_library("build/src/libtilefusion.so")
9-
108

119
def scatter_nd(scatter_data, scatter_indices, scatter_updates):
1210
torch.ops.tilefusion.scatter_nd(

pytilefusion/__version__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__version__ = '0.0.0'

requirements.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
cmake
2+
packaging
3+
setuptools>=49.4.0
4+
torch
5+
wheel

setup.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
# -------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License.
4+
# --------------------------------------------------------------------------
5+
6+
import os
7+
import subprocess
8+
from pathlib import Path
9+
10+
from setuptools import Command, Extension, find_packages, setup
11+
from setuptools.command.build_ext import build_ext
12+
13+
cur_path = Path(__file__).parent
14+
15+
16+
def get_requirements():
17+
"""Get Python package dependencies from requirements.txt."""
18+
with open(cur_path / "requirements.txt") as f:
19+
requirements = f.read().strip().split("\n")
20+
requirements = [req for req in requirements if "https" not in req]
21+
return requirements
22+
23+
24+
class CMakeExtension(Extension):
25+
""" specify the root folder of the CMake projects"""
26+
27+
def __init__(self, name, cmake_lists_dir=".", **kwargs):
28+
Extension.__init__(self, name, sources=[], **kwargs)
29+
self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)
30+
31+
32+
class CMakeBuildExt(build_ext):
33+
"""launches the CMake build."""
34+
35+
def copy_extensions_to_source(self) -> None:
36+
build_py = self.get_finalized_command("build_py")
37+
for ext in self.extensions:
38+
source_path = os.path.join(self.build_lib, "lib" + ext.name + ".so")
39+
inplace_file, _ = self._get_inplace_equivalent(build_py, ext)
40+
41+
target_path = os.path.join(
42+
build_py.build_lib, "pytilefusion", inplace_file
43+
)
44+
45+
# Always copy, even if source is older than destination, to ensure
46+
# that the right extensions for the current Python/platform are
47+
# used.
48+
if os.path.exists(source_path) or not ext.optional:
49+
self.copy_file(source_path, target_path, level=self.verbose)
50+
51+
def build_extension(self, ext: CMakeExtension) -> None:
52+
# Ensure that CMake is present and working
53+
try:
54+
subprocess.check_output(["cmake", "--version"])
55+
except OSError:
56+
raise RuntimeError("Cannot find CMake executable") from None
57+
58+
debug = int(
59+
os.environ.get("DEBUG", 0)
60+
) if self.debug is None else self.debug
61+
cfg = "Debug" if debug else "Release"
62+
63+
parallel_level = os.environ.get("CMAKE_BUILD_PARALLEL_LEVEL", None)
64+
if parallel_level is not None:
65+
self.parallel = int(parallel_level)
66+
else:
67+
self.parallel = os.cpu_count()
68+
69+
for ext in self.extensions:
70+
extdir = os.path.abspath(
71+
os.path.dirname(self.get_ext_fullpath(ext.name))
72+
)
73+
74+
cmake_args = [
75+
"-DCMAKE_BUILD_TYPE=%s" % cfg,
76+
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}".format(
77+
cfg.upper(), extdir
78+
), "-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY_{}={}".format(
79+
cfg.upper(), self.build_temp
80+
)
81+
]
82+
83+
# Adding CMake arguments set as environment variable
84+
if "CMAKE_ARGS" in os.environ:
85+
cmake_args += [
86+
item for item in os.environ["CMAKE_ARGS"].split(" ") if item
87+
]
88+
89+
if not os.path.exists(self.build_temp):
90+
os.makedirs(self.build_temp)
91+
92+
build_args = []
93+
build_args += ["--config", cfg]
94+
# Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level
95+
# across all generators.
96+
if (
97+
"CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ and
98+
hasattr(self, "parallel") and self.parallel
99+
):
100+
build_args += [f"-j{self.parallel}"]
101+
102+
build_temp = Path(self.build_temp) / ext.name
103+
if not build_temp.exists():
104+
build_temp.mkdir(parents=True)
105+
106+
# Config
107+
subprocess.check_call(["cmake", ext.cmake_lists_dir] + cmake_args,
108+
cwd=self.build_temp)
109+
110+
# Build
111+
subprocess.check_call(["cmake", "--build", "."] + build_args,
112+
cwd=self.build_temp)
113+
114+
print()
115+
self.copy_extensions_to_source()
116+
117+
118+
class Clean(Command):
119+
user_options = []
120+
121+
def initialize_options(self):
122+
pass
123+
124+
def finalize_options(self):
125+
pass
126+
127+
def run(self):
128+
import glob
129+
import re
130+
import shutil
131+
132+
with open(".gitignore") as f:
133+
ignores = f.read()
134+
pat = re.compile(r"^#( BEGIN NOT-CLEAN-FILES )?")
135+
for wildcard in filter(None, ignores.split("\n")):
136+
match = pat.match(wildcard)
137+
if match:
138+
if match.group(1):
139+
# Marker is found and stop reading .gitignore.
140+
break
141+
# Ignore lines which begin with '#'.
142+
else:
143+
# Don't remove absolute paths from the system
144+
wildcard = wildcard.lstrip("./")
145+
146+
for filename in glob.glob(wildcard):
147+
print(f"cleaning '{filename}'")
148+
try:
149+
os.remove(filename)
150+
except OSError:
151+
shutil.rmtree(filename, ignore_errors=True)
152+
153+
154+
description = ("PyTileFusion: A Python wrapper for tilefusion C++ library.")
155+
156+
with open(os.path.join("pytilefusion", "__version__.py")) as f:
157+
exec(f.read())
158+
159+
setup(
160+
name="tilefusion",
161+
version=__version__, # noqa F821
162+
description=description,
163+
author="Ying Cao, Chengxiang Qi",
164+
python_requires=">=3.10",
165+
packages=find_packages(exclude=[""]),
166+
install_requires=get_requirements(),
167+
ext_modules=[CMakeExtension("tilefusion")],
168+
cmdclass={
169+
"build_ext": CMakeBuildExt,
170+
"clean": Clean,
171+
},
172+
)

src/CMakeLists.txt

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,15 @@ set_target_properties(
2121
CUDA_SEPARABLE_COMPILATION ON)
2222

2323
target_compile_options(
24-
${TARGET} PUBLIC $<$<COMPILE_LANGUAGE:CUDA>: -Werror,-Wall -rdc=true
25-
-std=c++20 -fconcepts -fpermissive>)
24+
${TARGET}
25+
PUBLIC $<$<COMPILE_LANGUAGE:CUDA>:
26+
-Werror,-Wall
27+
-rdc=true
28+
-std=c++20
29+
-fconcepts
30+
-fpermissive
31+
--use_fast_math
32+
--generate-line-info
33+
>)
2634
target_compile_features(${TARGET} PUBLIC cxx_std_20 cuda_std_20)
2735
target_link_libraries(${TARGET} "${TORCH_LIBRARIES}")

tests/python/context.py

Lines changed: 0 additions & 11 deletions
This file was deleted.

tests/python/test_flash_attn.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import unittest
77

8-
import context # noqa: F401
98
import torch
109

1110
from pytilefusion import TiledFlashAttention

tests/python/test_scatter_nd.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import random
77
import unittest
88

9-
import context # noqa: F401
109
import torch
1110

1211
from pytilefusion import scatter_nd

0 commit comments

Comments
 (0)