diff --git a/pytilefusion/__init__.py b/pytilefusion/__init__.py index 90d3ce4..1309ce6 100644 --- a/pytilefusion/__init__.py +++ b/pytilefusion/__init__.py @@ -3,17 +3,32 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- +import os + import torch +def _load_library(filename: str) -> bool: + """Load a shared library from the given filename.""" + try: + libdir = os.path.dirname(os.path.dirname(__file__)) + torch.ops.load_library(os.path.join(libdir, filename)) + print(f"Successfully loaded: '{filename}'") + except Exception as error: + print(f"Fail to load library: '{filename}', {error}\n") + + +_load_library("libtilefusion.so") + + def scatter_nd(scatter_data, scatter_indices, scatter_updates): - torch.ops.tilefusion.scatter_nd( + torch.ops.tilefusion_ops.scatter_nd( scatter_data, scatter_updates, scatter_indices ) def flash_attention_fwd(Q, K, V, Out, m, n, k, p): - torch.ops.tilefusion.flash_attention_fwd(Q, K, V, Out, m, n, k, p) + torch.ops.tilefusion_ops.flash_attention_fwd(Q, K, V, Out, m, n, k, p) class TiledFlashAttention(): diff --git a/setup.py b/setup.py index 4820566..f2ca2c3 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ def get_requirements(): class CMakeExtension(Extension): """ specify the root folder of the CMake projects""" - def __init__(self, name, cmake_lists_dir=".", **kwargs): + def __init__(self, name="tilefusion", cmake_lists_dir=".", **kwargs): Extension.__init__(self, name, sources=[], **kwargs) self.cmake_lists_dir = os.path.abspath(cmake_lists_dir) @@ -32,10 +32,15 @@ def __init__(self, name, cmake_lists_dir=".", **kwargs): class CMakeBuildExt(build_ext): """launches the CMake build.""" + def get_ext_filename(self, name): + return f"lib{name}.so" + def copy_extensions_to_source(self) -> None: build_py = self.get_finalized_command("build_py") for ext in self.extensions: - source_path = os.path.join(self.build_lib, "lib" + ext.name + ".so") + source_path = os.path.join( + self.build_lib, self.get_ext_filename(ext.name) + ) inplace_file, _ = self._get_inplace_equivalent(build_py, ext) target_path = os.path.join( @@ -164,7 +169,7 @@ def run(self): python_requires=">=3.10", packages=find_packages(exclude=[""]), install_requires=get_requirements(), - ext_modules=[CMakeExtension("tilefusion")], + ext_modules=[CMakeExtension()], cmdclass={ "build_ext": CMakeBuildExt, "clean": Clean, diff --git a/src/torch_bind.cc b/src/torch_bind.cc index 76412a1..3d3baad 100644 --- a/src/torch_bind.cc +++ b/src/torch_bind.cc @@ -3,14 +3,11 @@ #include "kernels/mod.hpp" -#include - namespace tilefusion { using namespace tilefusion::kernels; -TORCH_LIBRARY(tilefusion, t) { +TORCH_LIBRARY(tilefusion_ops, t) { t.def("scatter_nd", &custom_scatter_op); t.def("flash_attention_fwd", &custom_flash_attention_op); }; - } // namespace tilefusion