Skip to content

Commit

Permalink
bug fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
lcy-seso committed Dec 31, 2024
1 parent 981a39a commit 681c0f5
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 9 deletions.
19 changes: 17 additions & 2 deletions pytilefusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,32 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------

import os

import torch


def _load_library(filename: str) -> bool:
"""Load a shared library from the given filename."""
try:
libdir = os.path.dirname(os.path.dirname(__file__))
torch.ops.load_library(os.path.join(libdir, filename))
print(f"Successfully loaded: '{filename}'")
except Exception as error:
print(f"Fail to load library: '{filename}', {error}\n")


_load_library("libtilefusion.so")


def scatter_nd(scatter_data, scatter_indices, scatter_updates):
torch.ops.tilefusion.scatter_nd(
torch.ops.tilefusion_ops.scatter_nd(
scatter_data, scatter_updates, scatter_indices
)


def flash_attention_fwd(Q, K, V, Out, m, n, k, p):
torch.ops.tilefusion.flash_attention_fwd(Q, K, V, Out, m, n, k, p)
torch.ops.tilefusion_ops.flash_attention_fwd(Q, K, V, Out, m, n, k, p)


class TiledFlashAttention():
Expand Down
11 changes: 8 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,23 @@ def get_requirements():
class CMakeExtension(Extension):
""" specify the root folder of the CMake projects"""

def __init__(self, name, cmake_lists_dir=".", **kwargs):
def __init__(self, name="tilefusion", cmake_lists_dir=".", **kwargs):
Extension.__init__(self, name, sources=[], **kwargs)
self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)


class CMakeBuildExt(build_ext):
"""launches the CMake build."""

def get_ext_filename(self, name):
return f"lib{name}.so"

def copy_extensions_to_source(self) -> None:
build_py = self.get_finalized_command("build_py")
for ext in self.extensions:
source_path = os.path.join(self.build_lib, "lib" + ext.name + ".so")
source_path = os.path.join(
self.build_lib, self.get_ext_filename(ext.name)
)
inplace_file, _ = self._get_inplace_equivalent(build_py, ext)

target_path = os.path.join(
Expand Down Expand Up @@ -164,7 +169,7 @@ def run(self):
python_requires=">=3.10",
packages=find_packages(exclude=[""]),
install_requires=get_requirements(),
ext_modules=[CMakeExtension("tilefusion")],
ext_modules=[CMakeExtension()],
cmdclass={
"build_ext": CMakeBuildExt,
"clean": Clean,
Expand Down
5 changes: 1 addition & 4 deletions src/torch_bind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,11 @@

#include "kernels/mod.hpp"

#include <torch/script.h>

namespace tilefusion {
using namespace tilefusion::kernels;

TORCH_LIBRARY(tilefusion, t) {
TORCH_LIBRARY(tilefusion_ops, t) {
t.def("scatter_nd", &custom_scatter_op);
t.def("flash_attention_fwd", &custom_flash_attention_op);
};

} // namespace tilefusion

0 comments on commit 681c0f5

Please sign in to comment.