fix(python): bug fix for loading built library. #29

Merged · 1 commit · Dec 31, 2024
6 changes: 3 additions & 3 deletions include/kernels/flash_attn.hpp
@@ -35,8 +35,8 @@ template <typename InType, typename AccType, typename OutType,
 void run_flash_attention(const InType* dQ, const InType* dK, const InType* dV,
                          OutType* dO);

-void custom_flash_attention_op(const torch::Tensor& Q, const torch::Tensor& K,
-                               const torch::Tensor& V, torch::Tensor& O,
-                               int64_t m, int64_t n, int64_t k, int64_t p);
+void flash_attention_op(const torch::Tensor& Q, const torch::Tensor& K,
+                        const torch::Tensor& V, torch::Tensor& O, int64_t m,
+                        int64_t n, int64_t k, int64_t p);

 }  // namespace tilefusion::kernels
4 changes: 2 additions & 2 deletions include/kernels/scatter_nd.hpp
@@ -39,7 +39,7 @@ template <typename T>
 void scatter_nd(torch::Tensor& data, const torch::Tensor& updates,
                 const torch::Tensor& indices);

-void custom_scatter_op(torch::Tensor& data, const torch::Tensor& updates,
-                       const torch::Tensor& indices);
+void scatter_op(torch::Tensor& data, const torch::Tensor& updates,
+                const torch::Tensor& indices);

 }  // namespace tilefusion::kernels
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -30,7 +30,7 @@ Issues = "https://github.com/microsoft/TileFusion/issues"
 requires = [
     "cmake",
     "packaging",
-    "setuptools>=49.4.0",
+    "setuptools>=64.0.0",
     "wheel",
 ]
 build-backend = "setuptools.build_meta"
15 changes: 15 additions & 0 deletions pytilefusion/__init__.py
@@ -3,9 +3,24 @@
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------

+import os
+
 import torch


+def _load_library(filename: str) -> None:
+    """Load a shared library from the given filename."""
+    try:
+        libdir = os.path.dirname(os.path.dirname(__file__))
+        torch.ops.load_library(os.path.join(libdir, "pytilefusion", filename))
+        print(f"Successfully loaded: '{filename}'")
+    except Exception as error:
+        print(f"Failed to load library: '{filename}', {error}\n")
+
+
+_load_library("libtilefusion.so")
+
+
 def scatter_nd(scatter_data, scatter_indices, scatter_updates):
     torch.ops.tilefusion.scatter_nd(
         scatter_data, scatter_updates, scatter_indices
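The fix boils down to resolving the shared object relative to the package itself instead of the current working directory. A minimal sketch of the path arithmetic, assuming the install layout this PR's setup.py produces (the .so copied into the pytilefusion package directory):

```python
import os

# __file__ is e.g. .../site-packages/pytilefusion/__init__.py
package_dir = os.path.dirname(os.path.abspath(__file__))  # .../pytilefusion
libdir = os.path.dirname(package_dir)                     # .../site-packages
lib_path = os.path.join(libdir, "pytilefusion", "libtilefusion.so")
# ... which is simply <package_dir>/libtilefusion.so, wherever the package
# was installed. torch.ops.load_library(lib_path) then registers every op
# the library defines under torch.ops.tilefusion.
```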
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,5 +1,5 @@
 cmake
 packaging
-setuptools>=49.4.0
+setuptools>=64.0.0
 torch
 wheel
11 changes: 8 additions & 3 deletions setup.py
@@ -24,18 +24,23 @@ def get_requirements():
 class CMakeExtension(Extension):
     """Specify the root folder of the CMake project."""

-    def __init__(self, name, cmake_lists_dir=".", **kwargs):
+    def __init__(self, name="tilefusion", cmake_lists_dir=".", **kwargs):
         Extension.__init__(self, name, sources=[], **kwargs)
         self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)


 class CMakeBuildExt(build_ext):
     """Launches the CMake build."""

+    def get_ext_filename(self, name):
+        return f"lib{name}.so"
+
     def copy_extensions_to_source(self) -> None:
         build_py = self.get_finalized_command("build_py")
         for ext in self.extensions:
-            source_path = os.path.join(self.build_lib, "lib" + ext.name + ".so")
+            source_path = os.path.join(
+                self.build_lib, self.get_ext_filename(ext.name)
+            )
             inplace_file, _ = self._get_inplace_equivalent(build_py, ext)

             target_path = os.path.join(
@@ -164,7 +169,7 @@ def run(self):
     python_requires=">=3.10",
     packages=find_packages(exclude=[""]),
     install_requires=get_requirements(),
-    ext_modules=[CMakeExtension("tilefusion")],
+    ext_modules=[CMakeExtension()],
     cmdclass={
         "build_ext": CMakeBuildExt,
         "clean": Clean,
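The `get_ext_filename` override is the crux of the setup.py change: stock `build_ext` derives the extension filename from the interpreter's `EXT_SUFFIX`, while CMake emits a plain `lib<name>.so`, so `copy_extensions_to_source` was previously looking for a file under a name CMake never produces. A sketch of the mismatch (the exact suffix is platform-dependent; the one shown is an example, and the CMake naming assumes a `SHARED` library target):

```python
import sysconfig

# What a stock setuptools build_ext would expect, e.g.
# "tilefusion.cpython-310-x86_64-linux-gnu.so" on CPython 3.10 / Linux:
stock_name = "tilefusion" + sysconfig.get_config_var("EXT_SUFFIX")

# What CMake actually writes for a shared library target named
# "tilefusion", and what the override now returns:
cmake_name = "libtilefusion.so"
```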
6 changes: 3 additions & 3 deletions src/kernels/flash_attn.cu
@@ -424,9 +424,9 @@ void run_flash_attention(const InType* dQ, const InType* dK, const InType* dV,
     cudaDeviceSynchronize();
 }

-void custom_flash_attention_op(const torch::Tensor& Q, const torch::Tensor& K,
-                               const torch::Tensor& V, torch::Tensor& O,
-                               int64_t m, int64_t n, int64_t k, int64_t p) {
+void flash_attention_op(const torch::Tensor& Q, const torch::Tensor& K,
+                        const torch::Tensor& V, torch::Tensor& O, int64_t m,
+                        int64_t n, int64_t k, int64_t p) {
     using InType = __half;
     using AccType = float;
     using OutType = __half;
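Since the entry point pins `InType`/`OutType` to `__half`, callers must hand it float16 CUDA tensors. A hedged usage sketch from the Python side; the shape conventions for `m`, `n`, `k`, `p` are my reading of a Q/K/V attention pipeline, not something this diff documents:

```python
import torch
import pytilefusion  # importing triggers _load_library("libtilefusion.so")

m, n, k, p = 64, 64, 128, 128  # hypothetical problem sizes
Q = torch.randn(m, k, device="cuda", dtype=torch.float16)
K = torch.randn(n, k, device="cuda", dtype=torch.float16)
V = torch.randn(n, p, device="cuda", dtype=torch.float16)
O = torch.empty(m, p, device="cuda", dtype=torch.float16)

# float16 is required because the op hard-codes __half on the C++ side;
# O receives the output.
torch.ops.tilefusion.flash_attention_fwd(Q, K, V, O, m, n, k, p)
```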
4 changes: 2 additions & 2 deletions src/kernels/scatter_nd.cu
@@ -114,8 +114,8 @@ void scatter_nd(torch::Tensor& data, const torch::Tensor& updates,
                 slice_size);
 }

-void custom_scatter_op(torch::Tensor& data, const torch::Tensor& updates,
-                       const torch::Tensor& indices) {
+void scatter_op(torch::Tensor& data, const torch::Tensor& updates,
+                const torch::Tensor& indices) {
     auto dtype = data.dtype();
     if (dtype == torch::kFloat32) {
         scatter_nd<float>(data, updates, indices);
17 changes: 12 additions & 5 deletions src/torch_bind.cc
@@ -3,14 +3,21 @@

 #include "kernels/mod.hpp"

+#include <torch/script.h>
+
 namespace tilefusion {
 using namespace tilefusion::kernels;

-TORCH_LIBRARY(tilefusion, t) {
-    t.def("scatter_nd", &custom_scatter_op);
-    t.def("flash_attention_fwd", &custom_flash_attention_op);
+TORCH_LIBRARY_IMPL(tilefusion, CUDA, m) {
+    m.impl("scatter_nd", scatter_op);
+    m.impl("flash_attention_fwd", flash_attention_op);
 };

+TORCH_LIBRARY(tilefusion, m) {
+    m.def("scatter_nd(Tensor(a!) data, Tensor updates, Tensor indices) -> ()");
+    m.def(
+        R"DOC(flash_attention_fwd(
+            Tensor(a!) Q,
+            Tensor K, Tensor V, Tensor O,
+            int m, int n, int k, int p) -> ()
+        )DOC");
+}
 }  // namespace tilefusion
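Registration is now split in two: `TORCH_LIBRARY` declares the op schemas, and `TORCH_LIBRARY_IMPL(tilefusion, CUDA, ...)` binds the kernels for the CUDA dispatch key, so the dispatcher routes (or rejects) by device. The `Tensor(a!)` annotation marks that argument as mutated in place. Once the library is loaded, the ops are ordinary `torch.ops` calls; a minimal sketch for `scatter_nd` follows (the index layout is an illustrative assumption):

```python
import torch
import pytilefusion  # loads libtilefusion.so and registers the ops

data = torch.zeros(8, 4, device="cuda")    # Tensor(a!): updated in place
updates = torch.ones(2, 4, device="cuda")
indices = torch.tensor([[1], [5]], device="cuda")  # hypothetical layout

# The schema ends in "-> ()": nothing is returned, the result lands in data.
torch.ops.tilefusion.scatter_nd(data, updates, indices)
```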