diff --git a/include/kernels/flash_attn.hpp b/include/kernels/flash_attn.hpp
index 46b1170..27b2c33 100644
--- a/include/kernels/flash_attn.hpp
+++ b/include/kernels/flash_attn.hpp
@@ -35,8 +35,8 @@
 template
 void scatter_nd(torch::Tensor& data, const torch::Tensor& updates,
                 const torch::Tensor& indices);
 
-void custom_scatter_op(torch::Tensor& data, const torch::Tensor& updates,
-                       const torch::Tensor& indices);
+void scatter_op(torch::Tensor& data, const torch::Tensor& updates,
+                const torch::Tensor& indices);
 
 }  // namespace tilefusion::kernels
diff --git a/pyproject.toml b/pyproject.toml
index 9945f48..e34a429 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,7 +30,7 @@ Issues = "https://github.com/microsoft/TileFusion/issues"
 requires = [
     "cmake",
     "packaging",
-    "setuptools>=49.4.0",
+    "setuptools>=64.0.0",
     "wheel",
 ]
 build-backend = "setuptools.build_meta"
diff --git a/pytilefusion/__init__.py b/pytilefusion/__init__.py
index 90d3ce4..98e771f 100644
--- a/pytilefusion/__init__.py
+++ b/pytilefusion/__init__.py
@@ -3,9 +3,24 @@
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------
 
+import os
+
 import torch
 
 
+def _load_library(filename: str) -> bool:
+    """Load a shared library from the given filename."""
+    try:
+        libdir = os.path.dirname(os.path.dirname(__file__))
+        torch.ops.load_library(os.path.join(libdir, "pytilefusion", filename))
+        print(f"Successfully loaded: '{filename}'")
+    except Exception as error:
+        print(f"Fail to load library: '{filename}', {error}\n")
+
+
+_load_library("libtilefusion.so")
+
+
 def scatter_nd(scatter_data, scatter_indices, scatter_updates):
     torch.ops.tilefusion.scatter_nd(
         scatter_data, scatter_updates, scatter_indices
diff --git a/requirements.txt b/requirements.txt
index 3fdbc00..d08b279 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 cmake
 packaging
-setuptools>=49.4.0
+setuptools>=64.0.0
 torch
 wheel
diff --git a/setup.py b/setup.py
index 4820566..f2ca2c3 100644
--- a/setup.py
+++ b/setup.py
@@ -24,7 +24,7 @@ def get_requirements():
 class CMakeExtension(Extension):
     """ specify the root folder of the CMake projects"""
 
-    def __init__(self, name, cmake_lists_dir=".", **kwargs):
+    def __init__(self, name="tilefusion", cmake_lists_dir=".", **kwargs):
         Extension.__init__(self, name, sources=[], **kwargs)
         self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)
 
@@ -32,10 +32,15 @@ def __init__(self, name, cmake_lists_dir=".", **kwargs):
 class CMakeBuildExt(build_ext):
     """launches the CMake build."""
 
+    def get_ext_filename(self, name):
+        return f"lib{name}.so"
+
     def copy_extensions_to_source(self) -> None:
         build_py = self.get_finalized_command("build_py")
         for ext in self.extensions:
-            source_path = os.path.join(self.build_lib, "lib" + ext.name + ".so")
+            source_path = os.path.join(
+                self.build_lib, self.get_ext_filename(ext.name)
+            )
             inplace_file, _ = self._get_inplace_equivalent(build_py, ext)
 
             target_path = os.path.join(
@@ -164,7 +169,7 @@ def run(self):
     python_requires=">=3.10",
     packages=find_packages(exclude=[""]),
     install_requires=get_requirements(),
-    ext_modules=[CMakeExtension("tilefusion")],
+    ext_modules=[CMakeExtension()],
     cmdclass={
         "build_ext": CMakeBuildExt,
         "clean": Clean,
diff --git a/src/kernels/flash_attn.cu b/src/kernels/flash_attn.cu
index 102f1d1..abc09e3 100644
--- a/src/kernels/flash_attn.cu
+++ b/src/kernels/flash_attn.cu
@@ -424,9 +424,9 @@ void run_flash_attention(const InType* dQ, const InType* dK, const InType* dV,
     cudaDeviceSynchronize();
 }
 
-void custom_flash_attention_op(const torch::Tensor& Q, const torch::Tensor& K,
-                               const torch::Tensor& V, torch::Tensor& O,
-                               int64_t m, int64_t n, int64_t k, int64_t p) {
+void flash_attention_op(const torch::Tensor& Q, const torch::Tensor& K,
+                        const torch::Tensor& V, torch::Tensor& O, int64_t m,
+                        int64_t n, int64_t k, int64_t p) {
     using InType = __half;
     using AccType = float;
     using OutType = __half;
diff --git a/src/kernels/scatter_nd.cu b/src/kernels/scatter_nd.cu
index 1b491d9..e91ab82 100644
--- a/src/kernels/scatter_nd.cu
+++ b/src/kernels/scatter_nd.cu
@@ -114,8 +114,8 @@ void scatter_nd(torch::Tensor& data, const torch::Tensor& updates,
                 slice_size);
 }
 
-void custom_scatter_op(torch::Tensor& data, const torch::Tensor& updates,
-                       const torch::Tensor& indices) {
+void scatter_op(torch::Tensor& data, const torch::Tensor& updates,
+                const torch::Tensor& indices) {
     auto dtype = data.dtype();
     if (dtype == torch::kFloat32) {
         scatter_nd(data, updates, indices);
diff --git a/src/torch_bind.cc b/src/torch_bind.cc
index 76412a1..51778ef 100644
--- a/src/torch_bind.cc
+++ b/src/torch_bind.cc
@@ -3,14 +3,21 @@
 
 #include "kernels/mod.hpp"
 
-#include
-
 namespace tilefusion {
 using namespace tilefusion::kernels;
 
-TORCH_LIBRARY(tilefusion, t) {
-    t.def("scatter_nd", &custom_scatter_op);
-    t.def("flash_attention_fwd", &custom_flash_attention_op);
+TORCH_LIBRARY_IMPL(tilefusion, CUDA, m) {
+    m.impl("scatter_nd", scatter_op);
+    m.impl("flash_attention_fwd", flash_attention_op);
 };
 
+TORCH_LIBRARY(tilefusion, m) {
+    m.def("scatter_nd(Tensor(a!) data, Tensor updates, Tensor indices) -> ()");
+    m.def(
+        R"DOC(flash_attention_fwd(
+            Tensor(a!) Q,
+            Tensor K, Tensor V, Tensor O,
+            int m, int n, int k, int p) -> ()
+        )DOC");
+}
 }  // namespace tilefusion
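
Usage note: with this patch, `pytilefusion/__init__.py` loads `libtilefusion.so` at import time and its `scatter_nd` wrapper forwards to the schema registered in `src/torch_bind.cc`. A minimal sketch of calling the renamed op follows, assuming the extension is built and a CUDA device is available; the tensor shapes and the index dtype/layout are illustrative assumptions, not taken from the patch.

import torch

from pytilefusion import scatter_nd

# Assumed example inputs: scatter_op in scatter_nd.cu handles float32 tensors,
# and the kernel is registered under the CUDA dispatch key only.
data = torch.zeros(8, 16, device="cuda", dtype=torch.float32)
updates = torch.rand(4, 16, device="cuda", dtype=torch.float32)
indices = torch.tensor([[0], [2], [5], [7]], device="cuda", dtype=torch.int64)

# The wrapper forwards to torch.ops.tilefusion.scatter_nd(data, updates, indices);
# `data` is mutated in place, matching the Tensor(a!) annotation in the schema.
scatter_nd(data, indices, updates)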