Skip to content

Commit 6c04495

Browse files
committed
bug fix.
1 parent 981a39a commit 6c04495

File tree

9 files changed

+47
-20
lines changed

9 files changed

+47
-20
lines changed

include/kernels/flash_attn.hpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -35,8 +35,8 @@ template <typename InType, typename AccType, typename OutType,
3535
void run_flash_attention(const InType* dQ, const InType* dK, const InType* dV,
3636
OutType* dO);
3737

38-
void custom_flash_attention_op(const torch::Tensor& Q, const torch::Tensor& K,
39-
const torch::Tensor& V, torch::Tensor& O,
40-
int64_t m, int64_t n, int64_t k, int64_t p);
38+
void flash_attention_op(const torch::Tensor& Q, const torch::Tensor& K,
39+
const torch::Tensor& V, torch::Tensor& O, int64_t m,
40+
int64_t n, int64_t k, int64_t p);
4141

4242
} // namespace tilefusion::kernels

include/kernels/scatter_nd.hpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,7 @@ template <typename T>
3939
void scatter_nd(torch::Tensor& data, const torch::Tensor& updates,
4040
const torch::Tensor& indices);
4141

42-
void custom_scatter_op(torch::Tensor& data, const torch::Tensor& updates,
43-
const torch::Tensor& indices);
42+
void scatter_op(torch::Tensor& data, const torch::Tensor& updates,
43+
const torch::Tensor& indices);
4444

4545
} // namespace tilefusion::kernels

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ Issues = "https://github.com/microsoft/TileFusion/issues"
3030
requires = [
3131
"cmake",
3232
"packaging",
33-
"setuptools>=49.4.0",
33+
"setuptools>=64.0.0",
3434
"wheel",
3535
]
3636
build-backend = "setuptools.build_meta"

pytilefusion/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,9 +3,24 @@
33
# Licensed under the MIT License.
44
# --------------------------------------------------------------------------
55

6+
import os
7+
68
import torch
79

810

11+
def _load_library(filename: str) -> bool:
12+
"""Load a shared library from the given filename."""
13+
try:
14+
libdir = os.path.dirname(os.path.dirname(__file__))
15+
torch.ops.load_library(os.path.join(libdir, "pytilefusion", filename))
16+
print(f"Successfully loaded: '{filename}'")
17+
except Exception as error:
18+
print(f"Fail to load library: '{filename}', {error}\n")
19+
20+
21+
_load_library("libtilefusion.so")
22+
23+
924
def scatter_nd(scatter_data, scatter_indices, scatter_updates):
1025
torch.ops.tilefusion.scatter_nd(
1126
scatter_data, scatter_updates, scatter_indices

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
cmake
22
packaging
3-
setuptools>=49.4.0
3+
setuptools>=64.0.0
44
torch
55
wheel

setup.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -24,18 +24,23 @@ def get_requirements():
2424
class CMakeExtension(Extension):
2525
""" specify the root folder of the CMake projects"""
2626

27-
def __init__(self, name, cmake_lists_dir=".", **kwargs):
27+
def __init__(self, name="tilefusion", cmake_lists_dir=".", **kwargs):
2828
Extension.__init__(self, name, sources=[], **kwargs)
2929
self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)
3030

3131

3232
class CMakeBuildExt(build_ext):
3333
"""launches the CMake build."""
3434

35+
def get_ext_filename(self, name):
36+
return f"lib{name}.so"
37+
3538
def copy_extensions_to_source(self) -> None:
3639
build_py = self.get_finalized_command("build_py")
3740
for ext in self.extensions:
38-
source_path = os.path.join(self.build_lib, "lib" + ext.name + ".so")
41+
source_path = os.path.join(
42+
self.build_lib, self.get_ext_filename(ext.name)
43+
)
3944
inplace_file, _ = self._get_inplace_equivalent(build_py, ext)
4045

4146
target_path = os.path.join(
@@ -164,7 +169,7 @@ def run(self):
164169
python_requires=">=3.10",
165170
packages=find_packages(exclude=[""]),
166171
install_requires=get_requirements(),
167-
ext_modules=[CMakeExtension("tilefusion")],
172+
ext_modules=[CMakeExtension()],
168173
cmdclass={
169174
"build_ext": CMakeBuildExt,
170175
"clean": Clean,

src/kernels/flash_attn.cu

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -424,9 +424,9 @@ void run_flash_attention(const InType* dQ, const InType* dK, const InType* dV,
424424
cudaDeviceSynchronize();
425425
}
426426

427-
void custom_flash_attention_op(const torch::Tensor& Q, const torch::Tensor& K,
428-
const torch::Tensor& V, torch::Tensor& O,
429-
int64_t m, int64_t n, int64_t k, int64_t p) {
427+
void flash_attention_op(const torch::Tensor& Q, const torch::Tensor& K,
428+
const torch::Tensor& V, torch::Tensor& O, int64_t m,
429+
int64_t n, int64_t k, int64_t p) {
430430
using InType = __half;
431431
using AccType = float;
432432
using OutType = __half;

src/kernels/scatter_nd.cu

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -114,8 +114,8 @@ void scatter_nd(torch::Tensor& data, const torch::Tensor& updates,
114114
slice_size);
115115
}
116116

117-
void custom_scatter_op(torch::Tensor& data, const torch::Tensor& updates,
118-
const torch::Tensor& indices) {
117+
void scatter_op(torch::Tensor& data, const torch::Tensor& updates,
118+
const torch::Tensor& indices) {
119119
auto dtype = data.dtype();
120120
if (dtype == torch::kFloat32) {
121121
scatter_nd<float>(data, updates, indices);

src/torch_bind.cc

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -3,14 +3,21 @@
33

44
#include "kernels/mod.hpp"
55

6-
#include <torch/script.h>
7-
86
namespace tilefusion {
97
using namespace tilefusion::kernels;
108

11-
TORCH_LIBRARY(tilefusion, t) {
12-
t.def("scatter_nd", &custom_scatter_op);
13-
t.def("flash_attention_fwd", &custom_flash_attention_op);
9+
TORCH_LIBRARY_IMPL(tilefusion, CUDA, m) {
10+
m.impl("scatter_nd", scatter_op);
11+
m.impl("flash_attention_fwd", flash_attention_op);
1412
};
1513

14+
// Declares the operator schemas of the `tilefusion` namespace. In each schema
// the `Tensor(a!)` annotation is placed on the argument that the matching C++
// function takes by non-const reference.
TORCH_LIBRARY(tilefusion, m) {
    m.def("scatter_nd(Tensor(a!) data, Tensor updates, Tensor indices) -> ()");
    // flash_attention_op takes Q, K and V by const reference and O by
    // non-const reference, so O (not Q) carries the mutation annotation.
    m.def(
        R"DOC(flash_attention_fwd(
            Tensor Q,
            Tensor K, Tensor V, Tensor(a!) O,
            int m, int n, int k, int p) -> ()
        )DOC");
}
1623
} // namespace tilefusion

0 commit comments

Comments
 (0)