diff --git a/docs/get_started/Installation.md b/docs/get_started/Installation.md index ea980b59b..86c9ab022 100644 --- a/docs/get_started/Installation.md +++ b/docs/get_started/Installation.md @@ -9,7 +9,7 @@ - **CUDA Version**: >= 12.1 - **LLVM**: < 20 if you are using the bundled TVM submodule -We currently provide three methods to install **TileScale**: +Install **TileScale** with the following steps: **(optional)Prepare the container**: @@ -48,32 +48,41 @@ You can now run TileScale examples and develop your applications. **Example Usage:** -You can run TileScale examples: +From the project root: ```bash -cd /home/tilelang TILELANG_USE_DISTRIBUTED=1 python examples/distributed/example_allgather_gemm_overlapped.py ``` ## To use NVSHMEM APIs -Before running the examples using NVSHMEM APIs (e.g., [example_allgather.py](../../examples/distributed/example_allgather.py)), you need to build NVSHMEM library for device-side code generation. +Device-side code generation (kernels calling `nvshmem_*` on the GPU) requires NVSHMEM built from source (the pip package does not provide `libnvshmem_device`). Build from source and install the Python bindings as follows. + +**1. Build NVSHMEM from source** ```bash -pip install mpich # building NVSHMEM needs MPI -export NVSHMEM_SRC="your_custom_nvshmem_dir" # default to 3rdparty/nvshmem_src +pip install mpich # NVSHMEM build requires MPI cd tilelang/distributed -source build_nvshmem.sh +# Optional: set NVSHMEM_SRC to a custom path; default is ../../3rdparty/nvshmem_src +# For H100 (sm_90), add: bash build_nvshmem.sh --arch 90 +bash build_nvshmem.sh +# Then set the env vars printed at the end (NVSHMEM_SRC, LD_LIBRARY_PATH). ``` -You also need to install the `pynvshmem` package, which provides wrapped host-side Python API for NVSHMEM. + +The script downloads the NVSHMEM source tarball from NVIDIA; you may need to be logged in at [NVIDIA Developer](https://developer.nvidia.com) for the download to succeed. + +**2. Install pynvshmem (host-side Python API)** + +From the project root (ensure `NVSHMEM_SRC` is set, e.g. from step 1 in the same shell): ```bash -cd ./pynvshmem +cd tilelang/distributed/pynvshmem python setup.py install -export LD_LIBRARY_PATH="$NVSHMEM_SRC/build/src/lib:$LD_LIBRARY_PATH" +export LD_LIBRARY_PATH="${NVSHMEM_SRC}/build/src/lib:$LD_LIBRARY_PATH" ``` -Then you can test python import: +**3. Verify** + ```bash python -c "import pynvshmem" ``` diff --git a/examples/distributed/README.md b/examples/distributed/README.md index 48cf85488..73b435d1b 100644 --- a/examples/distributed/README.md +++ b/examples/distributed/README.md @@ -9,22 +9,33 @@ For example, ## Prerequisites -Before running the examples, you need to build NVSHMEM library for device-side code generation. +Before running the examples, you need NVSHMEM (either from source or the pip package) and the `pynvshmem` Python bindings. + +**Build NVSHMEM from source (from repo root):** ```bash -export NVSHMEM_SRC="your_custom_nvshmem_dir" # default to 3rdparty/nvshmem_src +pip install mpich cd tilelang/distributed -source build_nvshmem.sh +bash build_nvshmem.sh # optional: --arch 90 for H100 (sm_90). Then set NVSHMEM_SRC and LD_LIBRARY_PATH as printed. ``` -You also need to install the `pynvshmem` package, which provides wrapped host-side Python API for NVSHMEM. + +**Or install the prebuilt NVSHMEM package:** ```bash -cd ./pynvshmem +pip install nvidia-nvshmem-cu12 +``` + +**Install pynvshmem and set library path:** + +```bash +cd tilelang/distributed/pynvshmem python setup.py install +# If you built NVSHMEM from source: export LD_LIBRARY_PATH="$NVSHMEM_SRC/build/src/lib:$LD_LIBRARY_PATH" ``` -Then you can test python import: +Then verify: + ```bash python -c "import pynvshmem" ``` diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index ce605ac7d..b71290542 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -307,11 +307,6 @@ std::string CodeGenTileLangCUDA::Finish() { decl_stream << "#include \n"; } - if (use_nvshmem_) { - decl_stream << "#include \n"; - decl_stream << "#include \n"; - } - decl_stream << "#include \n"; if (enable_sparse_gemm_) { decl_stream << "#include \n"; @@ -2851,22 +2846,6 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { } } os << ")"; - } else if (op->op.same_as(tl::Quiet())) { - this->use_distributed_ = true; - this->use_nvshmem_ = true; - os << "nvshmem_quiet()"; - } else if (op->op.same_as(tl::Fence())) { - this->use_distributed_ = true; - this->use_nvshmem_ = true; - os << "nvshmem_fence()"; - } else if (op->op.same_as(tl::SyncAll())) { - this->use_distributed_ = true; - this->use_nvshmem_ = true; - os << "nvshmem_sync_all()"; - } else if (op->op.same_as(tl::BarrierAll())) { - this->use_distributed_ = true; - this->use_nvshmem_ = true; - os << "nvshmem_barrier_all()"; } else if (op->op.same_as(tl::fence_cta())) { this->use_distributed_ = true; os << "tl::memory_fence_cta()"; diff --git a/src/target/codegen_cuda.h b/src/target/codegen_cuda.h index 6c5f89e07..7aacfc9ae 100644 --- a/src/target/codegen_cuda.h +++ b/src/target/codegen_cuda.h @@ -19,11 +19,19 @@ * Utility function for judging whether distributed mode is enabled. * This is used to determine whether to include distributed.h in the generated * code. + * + * Accepted truthy values: "1", "true", "on" (case-insensitive), consistent + * with the Python-side Environment.USE_DISTRIBUTED property. */ static inline bool use_distributed() { const char *env = std::getenv("TILELANG_USE_DISTRIBUTED"); if (env) { - return std::string(env) == "1"; + std::string val(env); + + // Convert to lowercase for case-insensitive comparison + for (auto &c : val) + c = std::tolower(c); + return val == "1" || val == "true" || val == "on"; } return false; } diff --git a/tilelang/contrib/nvcc.py b/tilelang/contrib/nvcc.py index 0f6ec52eb..2b628c616 100644 --- a/tilelang/contrib/nvcc.py +++ b/tilelang/contrib/nvcc.py @@ -8,28 +8,7 @@ import subprocess import warnings import contextlib -from tilelang.env import CUDA_HOME, CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH - - -def _get_nvshmem_include_path(): - """Get NVSHMEM include path from pip-installed nvidia-nvshmem-cu12 or environment.""" - # Try pip-installed nvidia-nvshmem-cu12 - try: - import nvidia.nvshmem - - nvshmem_path = nvidia.nvshmem.__path__[0] - include_path = os.path.join(nvshmem_path, "include") - if os.path.exists(include_path): - return include_path - except ImportError: - pass - # Try environment variable - nvshmem_home = os.environ.get("NVSHMEM_HOME", "") - if nvshmem_home: - include_path = os.path.join(nvshmem_home, "include") - if os.path.exists(include_path): - return include_path - return None +from tilelang.env import CUDA_HOME, CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH, env as _env import shutil @@ -177,8 +156,8 @@ def default_compile_options(compile_flags: list[str] | None = None) -> list[str] except Exception: pass - # Add NVSHMEM include path for distributed support - nvshmem_include = _get_nvshmem_include_path() + # Add NVSHMEM include path for distributed support (centralized in env.py) + nvshmem_include = _env.NVSHMEM_INCLUDE_DIR if nvshmem_include: options.append(f"-I{nvshmem_include}") diff --git a/tilelang/distributed/build_nvshmem.sh b/tilelang/distributed/build_nvshmem.sh index 0ed532b25..9f3d2fcb4 100755 --- a/tilelang/distributed/build_nvshmem.sh +++ b/tilelang/distributed/build_nvshmem.sh @@ -1,103 +1,75 @@ #!/bin/bash +# Build NVSHMEM from source for TileLang device-side use. Run from repo root or set NVSHMEM_SRC. +# Usage: source build_nvshmem.sh [--arch 90] [--jobs N] [--force-download] +# Override at runtime: NVSHMEM_SRC, NVSHMEM_* (see below), CMAKE, NVSHMEM_VERSION. -if [ -z "${NVSHMEM_SRC}" ]; then - export NVSHMEM_SRC="$(realpath ../../3rdparty/nvshmem_src)" - echo "NVSHMEM_SRC not set, defaulting to ${NVSHMEM_SRC}" -else - NVSHMEM_SRC="$(realpath ${NVSHMEM_SRC})" - echo "Using NVSHMEM_SRC=${NVSHMEM_SRC}" -fi - -if [ -d "${NVSHMEM_SRC}" ]; then - if [ "$(ls -A ${NVSHMEM_SRC})" ]; then - echo "NVSHMEM_SRC directory (${NVSHMEM_SRC}) is not empty, cleaning it..." - rm -rf "${NVSHMEM_SRC}/"* - rm -rf "${NVSHMEM_SRC}/".* 2>/dev/null || true - fi -else - mkdir -p "${NVSHMEM_SRC}" -fi - -wget https://developer.nvidia.com/downloads/assets/secure/nvshmem/nvshmem_src_3.2.5-1.txz -tar zxvf nvshmem_src_3.2.5-1.txz -rm -rf nvshmem_src_3.2.5-1.txz - -mkdir -p "${NVSHMEM_SRC}" -mv nvshmem_src/* "${NVSHMEM_SRC}/" -mv nvshmem_src/.* "${NVSHMEM_SRC}/" 2>/dev/null || true -rmdir nvshmem_src - - -export NVSHMEM_PATH="${NVSHMEM_SRC}" - +VER="${NVSHMEM_VERSION:-3.2.5-1}" +TARBALL="nvshmem_src_${VER}.txz" +URL="${NVSHMEM_SOURCE_URL:-https://developer.nvidia.com/downloads/assets/secure/nvshmem/${TARBALL}}" SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" -PROJECT_ROOT=$(realpath "${SCRIPT_DIR}") -echo "SCRIPT_DIR: ${SCRIPT_DIR}" -echo "PROJECT_ROOT: ${PROJECT_ROOT}" -echo "NVSHMEM will be installed to: ${NVSHMEM_SRC}" +FORCE_DL="" ARCH="" JOBS="" - -# Iterate over the command-line arguments while [[ $# -gt 0 ]]; do - key="$1" - - case $key in - --arch) - # Process the arch argument - ARCH="$2" - shift # Skip the argument value - shift # Skip the argument key - ;; - --jobs) - # Process the jobs argument - JOBS="$2" - shift # Skip the argument value - shift # Skip the argument key - ;; - *) - # Unknown argument - echo "Unknown argument: $1" - shift # Skip the argument - ;; - esac + case "$1" in + -h|--help) echo "Usage: bash build_nvshmem.sh [--arch ARCH] [--jobs N] [--force-download]"; exit 0 ;; + --arch) ARCH="$2"; shift 2 ;; + --jobs) JOBS="$2"; shift 2 ;; + --force-download) FORCE_DL=1; shift ;; + *) shift ;; + esac done -if [[ -n "${ARCH}" ]]; then - export CMAKE_CUDA_ARCHITECTURES="${ARCH}" - CUDAARCH_ARGS="-DCMAKE_CUDA_ARCHITECTURES=${ARCH}" -fi - -if [[ -z "${JOBS}" ]]; then - JOBS=$(nproc --ignore 2) -fi - -export NVSHMEM_IBGDA_SUPPORT=0 -export NVSHMEM_IBGDA_SUPPORT_GPUMEM_ONLY=0 -export NVSHMEM_IBDEVX_SUPPORT=0 -export NVSHMEM_IBRC_SUPPORT=1 -export NVSHMEM_LIBFABRIC_SUPPORT=0 -export NVSHMEM_MPI_SUPPORT=1 -export NVSHMEM_USE_GDRCOPY=0 -export NVSHMEM_TORCH_SUPPORT=1 -export NVSHMEM_ENABLE_ALL_DEVICE_INLINING=1 - -pushd "${NVSHMEM_SRC}" -mkdir -p build -cd build -CMAKE=${CMAKE:-cmake} - -if [ ! -f CMakeCache.txt ]; then - ${CMAKE} .. \ - -DCMAKE_EXPORT_COMPILE_COMMANDS=1 \ - ${CUDAARCH_ARGS} \ - -DNVSHMEM_BUILD_TESTS=OFF \ - -DNVSHMEM_BUILD_EXAMPLES=OFF \ - -DNVSHMEM_BUILD_PACKAGES=OFF +export NVSHMEM_SRC="$(realpath "${NVSHMEM_SRC:-$SCRIPT_DIR/../../3rdparty/nvshmem_src}")" +export NVSHMEM_PATH="${NVSHMEM_SRC}" +JOBS="${JOBS:-$(nproc 2>/dev/null || echo 4)}" +ARCH="${ARCH:-$CMAKE_CUDA_ARCHITECTURES}" + +if [[ -n "$FORCE_DL" ]] || [[ ! -f "${NVSHMEM_SRC}/CMakeLists.txt" ]]; then + if [[ -f "${NVSHMEM_SRC}/CMakeLists.txt" ]]; then + rm -rf "${NVSHMEM_SRC:?}/"* + rm -rf "${NVSHMEM_SRC}/".* 2>/dev/null || true + else + mkdir -p "${NVSHMEM_SRC}" + fi + cd "${SCRIPT_DIR}" + [[ -f "${TARBALL}" ]] || { wget -q --show-progress "${URL}" -O "${TARBALL}" || { echo "Download failed (login at developer.nvidia.com?)." >&2; exit 1; }; } + tar -zxf "${TARBALL}" + rm -f "${TARBALL}" + [[ -d nvshmem_src ]] || { echo "Missing nvshmem_src after extract." >&2; exit 1; } + mv nvshmem_src/* "${NVSHMEM_SRC}/" + mv nvshmem_src/.* "${NVSHMEM_SRC}/" 2>/dev/null || true + rmdir nvshmem_src fi -make VERBOSE=1 -j"${JOBS}" -popd - -echo "NVSHMEM installed successfully to ${NVSHMEM_SRC}" +export NVSHMEM_IBGDA_SUPPORT="${NVSHMEM_IBGDA_SUPPORT:-0}" +export NVSHMEM_IBGDA_SUPPORT_GPUMEM_ONLY="${NVSHMEM_IBGDA_SUPPORT_GPUMEM_ONLY:-0}" +export NVSHMEM_IBDEVX_SUPPORT="${NVSHMEM_IBDEVX_SUPPORT:-0}" +export NVSHMEM_IBRC_SUPPORT="${NVSHMEM_IBRC_SUPPORT:-1}" +export NVSHMEM_LIBFABRIC_SUPPORT="${NVSHMEM_LIBFABRIC_SUPPORT:-0}" +export NVSHMEM_MPI_SUPPORT="${NVSHMEM_MPI_SUPPORT:-1}" +export NVSHMEM_USE_GDRCOPY="${NVSHMEM_USE_GDRCOPY:-0}" +export NVSHMEM_TORCH_SUPPORT="${NVSHMEM_TORCH_SUPPORT:-1}" +export NVSHMEM_ENABLE_ALL_DEVICE_INLINING="${NVSHMEM_ENABLE_ALL_DEVICE_INLINING:-1}" +[[ -z "${ARCH}" ]] || export CMAKE_CUDA_ARCHITECTURES="${ARCH}" + +cd "${NVSHMEM_SRC}" +mkdir -p build && cd build +CMAKE="${CMAKE:-cmake}" +[[ -f CMakeCache.txt ]] || ${CMAKE} .. \ + -DCMAKE_EXPORT_COMPILE_COMMANDS=1 \ + ${ARCH:+-DCMAKE_CUDA_ARCHITECTURES=${ARCH}} \ + -DNVSHMEM_BUILD_TESTS=OFF \ + -DNVSHMEM_BUILD_EXAMPLES=OFF \ + -DNVSHMEM_BUILD_PACKAGES=OFF + +make -j"${JOBS}" VERBOSE=1 + +echo "" +echo "NVSHMEM built successfully at: ${NVSHMEM_SRC}" +echo "" +echo "To use NVSHMEM, add to your environment (e.g. in ~/.bashrc or before running examples):" +echo " export NVSHMEM_SRC=\"${NVSHMEM_SRC}\"" +echo " export LD_LIBRARY_PATH=\"${NVSHMEM_SRC}/build/src/lib:\$LD_LIBRARY_PATH\"" +echo "" diff --git a/tilelang/engine/lower.py b/tilelang/engine/lower.py index 44557ec3a..294f8e02f 100644 --- a/tilelang/engine/lower.py +++ b/tilelang/engine/lower.py @@ -86,8 +86,9 @@ def tilelang_callback_cuda_compile(code, target, pass_config=None): "-I" + tl_template_path, "-I" + cutlass_path, ] - # Add NVSHMEM include path and library linking for distributed support - if env.USE_DISTRIBUTED and env.USE_NVSHMEM: + + # Add NVSHMEM include path and library linking when NVSHMEM is enabled. + if env.USE_NVSHMEM: if env.NVSHMEM_INCLUDE_DIR and env.NVSHMEM_LIB_PATH: options.append("-I" + env.NVSHMEM_INCLUDE_DIR) options.append("-L" + env.NVSHMEM_LIB_PATH) @@ -95,7 +96,7 @@ def tilelang_callback_cuda_compile(code, target, pass_config=None): options.append("-rdc=true") else: raise ValueError( - "TILELANG_USE_DISTRIBUTED is enabled but NVSHMEM paths not found. Install nvidia-nvshmem-cu12 via pip or set NVSHMEM_SRC." + "TILELANG_USE_NVSHMEM is enabled but NVSHMEM paths not found. Install nvidia-nvshmem-cu12 via pip or set NVSHMEM_SRC." ) # Merge extra device compiler flags from pass config, if provided diff --git a/tilelang/env.py b/tilelang/env.py index a15477f04..b99424ab5 100644 --- a/tilelang/env.py +++ b/tilelang/env.py @@ -256,8 +256,10 @@ class Environment: TVM_IMPORT_PYTHON_PATH = EnvVar("TVM_IMPORT_PYTHON_PATH", None) # NVSHMEM paths - auto-detect from pip-installed nvidia-nvshmem-cu12 or NVSHMEM_HOME - _nvshmem_include_dir: str | None = None - _nvshmem_lib_path: str | None = None + # _NVSHMEM_NOT_PROBED is a sentinel indicating that detection has not run yet. + _NVSHMEM_NOT_PROBED = object() + _nvshmem_include_dir: str | None = _NVSHMEM_NOT_PROBED + _nvshmem_lib_path: str | None = _NVSHMEM_NOT_PROBED @property def USE_NVSHMEM(self) -> bool: @@ -271,15 +273,24 @@ def USE_DISTRIBUTED(self) -> bool: @property def NVSHMEM_INCLUDE_DIR(self) -> str | None: - """Get NVSHMEM include directory, auto-detecting if needed.""" - if self._nvshmem_include_dir is None and self.USE_DISTRIBUTED: + """Get NVSHMEM include directory, auto-detecting if needed. + + Path discovery is independent of USE_NVSHMEM / USE_DISTRIBUTED flags; + those flags control whether the paths are *used*, not whether they are + *discovered*. This avoids duplicating detection logic elsewhere. + """ + if self._nvshmem_include_dir is self._NVSHMEM_NOT_PROBED: self._nvshmem_include_dir, self._nvshmem_lib_path = Environment._find_nvshmem_paths() return self._nvshmem_include_dir @property def NVSHMEM_LIB_PATH(self) -> str | None: - """Get NVSHMEM library path, auto-detecting if needed.""" - if self._nvshmem_lib_path is None and self.USE_DISTRIBUTED: + """Get NVSHMEM library path, auto-detecting if needed. + + See :pyattr:`NVSHMEM_INCLUDE_DIR` for the rationale on unconditional + detection. + """ + if self._nvshmem_lib_path is self._NVSHMEM_NOT_PROBED: self._nvshmem_include_dir, self._nvshmem_lib_path = Environment._find_nvshmem_paths() return self._nvshmem_lib_path diff --git a/tilelang/jit/adapter/cython/cython_wrapper.pyx b/tilelang/jit/adapter/cython/cython_wrapper.pyx index 0139be513..20e65377d 100644 --- a/tilelang/jit/adapter/cython/cython_wrapper.pyx +++ b/tilelang/jit/adapter/cython/cython_wrapper.pyx @@ -11,8 +11,7 @@ from tilelang.env import env def _use_nvshmem(): """Check if NVSHMEM is enabled in the environment.""" - val = str(env.USE_NVSHMEM).lower() - return val in ("1", "true", "yes", "on") + return env.USE_NVSHMEM if _use_nvshmem(): import pynvshmem diff --git a/tilelang/profiler/__init__.py b/tilelang/profiler/__init__.py index 01a546db1..e46c0795b 100644 --- a/tilelang/profiler/__init__.py +++ b/tilelang/profiler/__init__.py @@ -24,14 +24,12 @@ def _use_nvshmem(): """Check if NVSHMEM is enabled in the environment.""" - val = str(env.USE_NVSHMEM).lower() - return val in ("1", "true", "yes", "on") + return env.USE_NVSHMEM def _use_distributed(): """Check if distributed mode is enabled in the environment.""" - val = str(env.USE_DISTRIBUTED).lower() - return val in ("1", "true", "yes", "on") + return env.USE_DISTRIBUTED logger = logging.getLogger(__name__)