Skip to content

Commit

Permalink
fix setup.py.
Browse files Browse the repository at this point in the history
  • Loading branch information
lcy-seso committed Dec 27, 2024
1 parent 5dcc3e7 commit 9a2a918
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 16 deletions.
20 changes: 20 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,26 @@ cd TileFusion && git submodule update --init --recursive

TileFusion requires a C++20 host compiler, CUDA 12.0 or later, and GCC version 10.0 or higher to support C++20 features.

### Build from Source

#### Using Makefile
To build the project using the provided `Makefile`, simply run:
```bash
make
```

#### Building the Python Wrapper

1. Build the wheel:
```bash
python3 setup.py build bdist_wheel
```

2. Clean the build:
```bash
python3 setup.py clean
```

### Unit Test

- **Run a single unit test**: `make unit_test UNIT_TEST=test_scatter_nd.py`
Expand Down
2 changes: 0 additions & 2 deletions pytilefusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

import torch

torch.ops.load_library("build/src/libtilefusion.so")


def scatter_nd(scatter_data, scatter_indices, scatter_updates):
torch.ops.tilefusion.scatter_nd(
Expand Down
21 changes: 20 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,22 @@ def __init__(self, name, cmake_lists_dir=".", **kwargs):
class CMakeBuildExt(build_ext):
"""launches the CMake build."""

def copy_extensions_to_source(self) -> None:
build_py = self.get_finalized_command("build_py")
for ext in self.extensions:
source_path = os.path.join(self.build_lib, "lib" + ext.name + ".so")
inplace_file, _ = self._get_inplace_equivalent(build_py, ext)

target_path = os.path.join(
build_py.build_lib, "pytilefusion", inplace_file
)

# Always copy, even if source is older than destination, to ensure
# that the right extensions for the current Python/platform are
# used.
if os.path.exists(source_path) or not ext.optional:
self.copy_file(source_path, target_path, level=self.verbose)

def build_extension(self, ext: CMakeExtension) -> None:
# Ensure that CMake is present and working
try:
Expand Down Expand Up @@ -95,6 +111,9 @@ def build_extension(self, ext: CMakeExtension) -> None:
subprocess.check_call(["cmake", "--build", "."] + build_args,
cwd=self.build_temp)

print()
self.copy_extensions_to_source()


class clean(Command):
user_options = []
Expand Down Expand Up @@ -134,7 +153,7 @@ def run(self):

description = ("PyTileFusion: A Python wrapper for tilefusion C++ library.")

with open(os.path.join("pytilefusion", '__version__.py')) as f:
with open(os.path.join("pytilefusion", "__version__.py")) as f:
exec(f.read())

setup(
Expand Down
11 changes: 0 additions & 11 deletions tests/python/context.py

This file was deleted.

1 change: 0 additions & 1 deletion tests/python/test_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import unittest

import context # noqa: F401
import torch

from pytilefusion import TiledFlashAttention
Expand Down
1 change: 0 additions & 1 deletion tests/python/test_scatter_nd.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import random
import unittest

import context # noqa: F401
import torch

from pytilefusion import scatter_nd
Expand Down

0 comments on commit 9a2a918

Please sign in to comment.