diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..33fb987
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,56 @@
+repos:
+  - repo: https://github.com/PyCQA/isort
+    rev: 5.11.5
+    hooks:
+      - id: isort
+        args: ["--multi-line=7", "--sl", "--profile", "black", "--filter-files"]
+
+  - repo: https://github.com/psf/black
+    rev: 22.3.0
+    hooks:
+      - id: black
+
+  - repo: https://github.com/charliermarsh/ruff-pre-commit
+    rev: "v0.0.272"
+    hooks:
+      - id: ruff
+
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: a11d9314b22d8f8c7556443875b731ef05965464
+    hooks:
+      - id: check-merge-conflict
+      - id: check-symlinks
+      - id: detect-private-key
+        files: (?!.*paddle)^.*$
+      - id: end-of-file-fixer
+      - id: trailing-whitespace
+      - id: check-case-conflict
+      - id: check-yaml
+        exclude: "mkdocs.yml|recipe/meta.yaml"
+      - id: pretty-format-json
+        args: [--autofix]
+      - id: requirements-txt-fixer
+
+  - repo: https://github.com/Lucas-C/pre-commit-hooks
+    rev: v1.0.1
+    hooks:
+      - id: forbid-crlf
+        files: \.md$
+      - id: remove-crlf
+        files: \.md$
+      - id: forbid-tabs
+        files: \.md$
+      - id: remove-tabs
+        files: \.md$
+
+  - repo: local
+    hooks:
+      - id: clang-format
+        name: clang-format
+        description: Format files with ClangFormat
+        entry: bash .clang_format.hook -i
+        language: system
+        files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
+
+exclude: |
+    ^jointContribution/
diff --git a/setup_ops.py b/setup_ops.py
index cf9f270..f6ed43f 100644
--- a/setup_ops.py
+++ b/setup_ops.py
@@ -1,12 +1,73 @@
 import os
-import re
 import os.path as osp
+import re
 
 import paddle
 from paddle.utils.cpp_extension import CppExtension
 from paddle.utils.cpp_extension import CUDAExtension
 from paddle.utils.cpp_extension import setup
 
+PADDLE_PATH = os.path.dirname(paddle.__file__)
+PADDLE_INCLUDE_PATH = os.path.join(PADDLE_PATH, "include")
+PADDLE_LIB_PATH = os.path.join(PADDLE_PATH, "libs")
+
+# Optional root of a locally unpacked XPU SDK. When XPU_BASE_DIR is set, each
+# toolchain variable that is not already defined is pointed at its SDK
+# sub-directory; when it is unset the environment is left untouched, so the
+# *_PATH lookups below fall back to the files under the Paddle install dir.
+_XPU_SDK_SUBDIRS = {
+    "XHPC_PATH": "xhpc-ubuntu2004_x86_64",
+    "XRE_PATH": "xre-Linux-x86_64-5.0.21.22",
+    "CLANG_PATH": "xtdk-llvm15-ubuntu2004_x86_64",
+    "BKCL_PATH": "xccl_rdma-ubuntu_x86_64",
+}
+BASE_DIR = os.getenv("XPU_BASE_DIR")
+if BASE_DIR is not None:
+    for _var, _subdir in _XPU_SDK_SUBDIRS.items():
+        os.environ.setdefault(_var, os.path.join(BASE_DIR, _subdir))
+# os.environ['XFT_PATH'] = os.environ['XHPC_PATH']  # XFT lives under the XHPC directory
+# os.environ['XBLAS_PATH'] = os.environ['XHPC_PATH']  # XBLAS lives under the XHPC directory
+
+BKCL_PATH = os.getenv("BKCL_PATH")
+if BKCL_PATH is None:
+    BKCL_INC_PATH = os.path.join(PADDLE_INCLUDE_PATH, "xpu")
+    BKCL_LIB_PATH = os.path.join(PADDLE_LIB_PATH, "libbkcl.so")
+else:
+    BKCL_INC_PATH = os.path.join(BKCL_PATH, "include")
+    BKCL_LIB_PATH = os.path.join(BKCL_PATH, "so", "libbkcl.so")
+
+# XFT_PATH = os.getenv("XFT_PATH")
+# if XFT_PATH is None:
+#     XFT_INC_PATH = os.path.join(PADDLE_INCLUDE_PATH, "xft")
+#     XFT_LIB_PATH = os.path.join(PADDLE_LIB_PATH, "libxft.so")
+# else:
+#     XFT_INC_PATH = os.path.join(XFT_PATH, "include")
+#     XFT_LIB_PATH = os.path.join(XFT_PATH, "so", "libxft.so")
+
+XRE_PATH = os.getenv("XRE_PATH")
+if XRE_PATH is None:
+    XRE_INC_PATH = os.path.join(PADDLE_INCLUDE_PATH, "xre")
+    XRE_LIB_PATH = os.path.join(PADDLE_LIB_PATH, "libxpucuda.so")
+else:
+    XRE_INC_PATH = os.path.join(XRE_PATH, "include")
+    XRE_LIB_PATH = os.path.join(XRE_PATH, "so", "libxpucuda.so")
+
+# XFA_PATH = os.getenv("XFA_PATH")
+# if XFA_PATH is None:
+#     XFA_INC_PATH = os.path.join(PADDLE_INCLUDE_PATH, "xhpc", "xfa")
+#     XFA_LIB_PATH = os.path.join(PADDLE_LIB_PATH, "libxpu_flash_attention.so")
+# else:
+#     XFA_INC_PATH = os.path.join(XFA_PATH, "include")
+#     XFA_LIB_PATH = os.path.join(XFA_PATH, "so", "libxpu_flash_attention.so")
+
+# XBLAS_PATH = os.getenv("XBLAS_PATH")
+# if XBLAS_PATH is None:
+#     XBLAS_INC_PATH = os.path.join(PADDLE_INCLUDE_PATH, "xhpc", "xblas")
+#     XBLAS_LIB_PATH = os.path.join(PADDLE_LIB_PATH, "libxpu_blas.so")
+# else:
+#     XBLAS_INC_PATH = os.path.join(XBLAS_PATH, "include")
+#     XBLAS_LIB_PATH = os.path.join(XBLAS_PATH, "so", "libxpu_blas.so")
+
 
 def get_version():
     current_dir = osp.dirname(osp.abspath(__file__))
@@ -18,6 +79,7 @@ def get_version():
 
     raise RuntimeError("Cannot find __version__ in paddle_scatter/__init__.py")
 
+
 __version__ = get_version()
 
 
@@ -47,12 +109,15 @@ def get_sources():
         else:
             if item.endswith(".cc"):
                 cpp_files.append(os.path.join(csrc_dir_path, item))
-    return csrc_dir_path, cpp_files
+    return [csrc_dir_path], cpp_files
 
 
 def get_extensions():
     Extension = CppExtension
-    extra_compile_args = {'cxx': ['-O3']}
+    extra_objects = []
+    include_dirs, sources = get_sources()
+
+    extra_compile_args = {"cxx": ["-O3"]}
     if paddle.device.is_compiled_with_cuda():
         set_cuda_archs()
         Extension = CUDAExtension
@@ -61,12 +126,31 @@ def get_extensions():
         nvcc_flags += ["-O3"]
         nvcc_flags += ["--expt-relaxed-constexpr"]
         extra_compile_args["nvcc"] = nvcc_flags
+    elif paddle.device.is_compiled_with_xpu():
+        include_dirs += [
+            XRE_INC_PATH,
+            # XFT_INC_PATH,
+            BKCL_INC_PATH,
+            # XFA_INC_PATH,
+            # XBLAS_INC_PATH,
+        ]
+        extra_objects += [
+            XRE_LIB_PATH,
+            # XFT_LIB_PATH,
+            BKCL_LIB_PATH,
+            # XFA_LIB_PATH,
+            # XBLAS_LIB_PATH,
+        ]
+        # Append to (rather than replace) the defaults so -O3 is kept for XPU builds.
+        extra_compile_args["cxx"] += ["-D_GLIBCXX_USE_CXX11_ABI=1", "-DPADDLE_WITH_XPU"]
+    else:
+        raise RuntimeError("Only CUDA and XPU devices are supported")
 
-    src = get_sources()
     ext_modules = [
         Extension(
-            sources=src[1],
-            include_dirs=src[0],
+            sources=sources,
+            include_dirs=include_dirs,
+            extra_objects=extra_objects,
             extra_compile_args=extra_compile_args,
         )
     ]
@@ -80,6 +164,10 @@ def get_extensions():
     version=__version__,
     author="NKNaN",
     url="https://github.com/PFCCLab/paddle_scatter",
-    description="Paddle extension of scatter and segment operators with min and max reduction methods, originally from https://github.com/rusty1s/pytorch_scatter",
+    description=(
+        "Paddle extension of scatter and segment operators "
+        "with min and max reduction methods, "
+        "originally from https://github.com/rusty1s/pytorch_scatter"
+    ),
     ext_modules=get_extensions(),
 )