31 changes: 20 additions & 11 deletions docs/get_started/Installation.md
@@ -9,7 +9,7 @@
- **CUDA Version**: >= 12.1
- **LLVM**: < 20 if you are using the bundled TVM submodule

We currently provide three methods to install **TileScale**:
Install **TileScale** with the following steps:

**(Optional) Prepare the container:**

@@ -48,32 +48,41 @@ You can now run TileScale examples and develop your applications.

**Example Usage:**

You can run TileScale examples:
From the project root:

```bash
cd /home/tilelang
TILELANG_USE_DISTRIBUTED=1 python examples/distributed/example_allgather_gemm_overlapped.py
```

## To use NVSHMEM APIs

Before running the examples using NVSHMEM APIs (e.g., [example_allgather.py](../../examples/distributed/example_allgather.py)), you need to build NVSHMEM library for device-side code generation.
Device-side code generation (kernels calling `nvshmem_*` on the GPU) requires NVSHMEM built from source (the pip package does not provide `libnvshmem_device`). Build from source and install the Python bindings as follows.

**1. Build NVSHMEM from source**

```bash
pip install mpich # building NVSHMEM needs MPI
export NVSHMEM_SRC="your_custom_nvshmem_dir" # default to 3rdparty/nvshmem_src
pip install mpich # NVSHMEM build requires MPI
cd tilelang/distributed
source build_nvshmem.sh
# Optional: set NVSHMEM_SRC to a custom path; default is ../../3rdparty/nvshmem_src
# For H100 (sm_90), add: bash build_nvshmem.sh --arch 90
bash build_nvshmem.sh
# Then set the env vars printed at the end (NVSHMEM_SRC, LD_LIBRARY_PATH).
```
You also need to install the `pynvshmem` package, which provides wrapped host-side Python API for NVSHMEM.

The script downloads the NVSHMEM source tarball from NVIDIA; you may need to be logged in at [NVIDIA Developer](https://developer.nvidia.com) for the download to succeed.

**2. Install pynvshmem (host-side Python API)**

From the project root (ensure `NVSHMEM_SRC` is set, e.g. from step 1 in the same shell):

```bash
cd ./pynvshmem
cd tilelang/distributed/pynvshmem
python setup.py install
export LD_LIBRARY_PATH="$NVSHMEM_SRC/build/src/lib:$LD_LIBRARY_PATH"
export LD_LIBRARY_PATH="${NVSHMEM_SRC}/build/src/lib:$LD_LIBRARY_PATH"
```

Then you can test python import:
**3. Verify**

```bash
python -c "import pynvshmem"
```
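Beyond a bare import, the environment set up in the steps above can be sanity-checked first. This is a sketch, not part of the repo: `nvshmem_preflight` is a hypothetical helper, and the `build/src/lib` layout under `NVSHMEM_SRC` matches the export shown in step 2.

```python
import os


def nvshmem_preflight(env=None):
    """Return a list of likely problems with the NVSHMEM environment.

    Checks the two variables the install steps set: NVSHMEM_SRC and
    LD_LIBRARY_PATH (which must contain $NVSHMEM_SRC/build/src/lib).
    """
    env = os.environ if env is None else env
    problems = []
    src = env.get("NVSHMEM_SRC", "")
    if not src:
        problems.append("NVSHMEM_SRC is not set")
    # The built shared libraries live under build/src/lib in the source tree.
    lib_dir = os.path.join(src, "build", "src", "lib") if src else ""
    if lib_dir and lib_dir not in env.get("LD_LIBRARY_PATH", "").split(":"):
        problems.append(lib_dir + " missing from LD_LIBRARY_PATH")
    return problems
```

Run it before `import pynvshmem` to get an actionable message instead of a loader error.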
23 changes: 17 additions & 6 deletions examples/distributed/README.md
@@ -9,22 +9,33 @@ For example,

## Prerequisites

Before running the examples, you need to build NVSHMEM library for device-side code generation.
Before running the examples, you need NVSHMEM (either from source or the pip package) and the `pynvshmem` Python bindings.

**Build NVSHMEM from source (from repo root):**

```bash
export NVSHMEM_SRC="your_custom_nvshmem_dir" # default to 3rdparty/nvshmem_src
pip install mpich
cd tilelang/distributed
source build_nvshmem.sh
bash build_nvshmem.sh # optional: --arch 90 for H100 (sm_90). Then set NVSHMEM_SRC and LD_LIBRARY_PATH as printed.
```
You also need to install the `pynvshmem` package, which provides wrapped host-side Python API for NVSHMEM.

**Or install the prebuilt NVSHMEM package:**

```bash
cd ./pynvshmem
pip install nvidia-nvshmem-cu12
```

**Install pynvshmem and set library path:**

```bash
cd tilelang/distributed/pynvshmem
python setup.py install
# If you built NVSHMEM from source:
export LD_LIBRARY_PATH="$NVSHMEM_SRC/build/src/lib:$LD_LIBRARY_PATH"
```

Then you can test python import:
Then verify:

```bash
python -c "import pynvshmem"
```
21 changes: 0 additions & 21 deletions src/target/codegen_cuda.cc
@@ -307,11 +307,6 @@ std::string CodeGenTileLangCUDA::Finish() {
decl_stream << "#include <curand_kernel.h>\n";
}

if (use_nvshmem_) {
decl_stream << "#include <nvshmem.h>\n";
decl_stream << "#include <nvshmemx.h>\n";
}

decl_stream << "#include <tl_templates/cuda/gemm.h>\n";
if (enable_sparse_gemm_) {
decl_stream << "#include <tl_templates/cuda/gemm_sp.h>\n";
@@ -2851,22 +2846,6 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
}
}
os << ")";
} else if (op->op.same_as(tl::Quiet())) {
this->use_distributed_ = true;
this->use_nvshmem_ = true;
os << "nvshmem_quiet()";
} else if (op->op.same_as(tl::Fence())) {
this->use_distributed_ = true;
this->use_nvshmem_ = true;
os << "nvshmem_fence()";
} else if (op->op.same_as(tl::SyncAll())) {
this->use_distributed_ = true;
this->use_nvshmem_ = true;
os << "nvshmem_sync_all()";
} else if (op->op.same_as(tl::BarrierAll())) {
this->use_distributed_ = true;
this->use_nvshmem_ = true;
os << "nvshmem_barrier_all()";
} else if (op->op.same_as(tl::fence_cta())) {
this->use_distributed_ = true;
os << "tl::memory_fence_cta()";
10 changes: 9 additions & 1 deletion src/target/codegen_cuda.h
@@ -19,11 +19,19 @@
* Utility function for judging whether distributed mode is enabled.
* This is used to determine whether to include distributed.h in the generated
* code.
*
* Accepted truthy values: "1", "true", "on" (case-insensitive), consistent
* with the Python-side Environment.USE_DISTRIBUTED property.
*/
static inline bool use_distributed() {
const char *env = std::getenv("TILELANG_USE_DISTRIBUTED");
if (env) {
return std::string(env) == "1";
std::string val(env);

// Convert to lowercase for case-insensitive comparison
for (auto &c : val)
c = std::tolower(c);
return val == "1" || val == "true" || val == "on";
}
return false;
}
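For reference, the same truthy-value parsing can be mirrored on the Python side (a sketch; the real `Environment.USE_DISTRIBUTED` property in tilelang may be implemented differently):

```python
import os


def use_distributed(env=None):
    """Case-insensitive check of TILELANG_USE_DISTRIBUTED, mirroring the
    C++ helper above: "1", "true", and "on" (any case) are truthy."""
    env = os.environ if env is None else env
    val = env.get("TILELANG_USE_DISTRIBUTED")
    if val is None:
        return False
    return val.lower() in ("1", "true", "on")
```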
27 changes: 3 additions & 24 deletions tilelang/contrib/nvcc.py
@@ -8,28 +8,7 @@
import subprocess
import warnings
import contextlib
from tilelang.env import CUDA_HOME, CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH


def _get_nvshmem_include_path():
"""Get NVSHMEM include path from pip-installed nvidia-nvshmem-cu12 or environment."""
# Try pip-installed nvidia-nvshmem-cu12
try:
import nvidia.nvshmem

nvshmem_path = nvidia.nvshmem.__path__[0]
include_path = os.path.join(nvshmem_path, "include")
if os.path.exists(include_path):
return include_path
except ImportError:
pass
# Try environment variable
nvshmem_home = os.environ.get("NVSHMEM_HOME", "")
if nvshmem_home:
include_path = os.path.join(nvshmem_home, "include")
if os.path.exists(include_path):
return include_path
return None
from tilelang.env import CUDA_HOME, CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH, env as _env


import shutil
@@ -177,8 +156,8 @@ def default_compile_options(compile_flags: list[str] | None = None) -> list[str]
except Exception:
pass

# Add NVSHMEM include path for distributed support
nvshmem_include = _get_nvshmem_include_path()
# Add NVSHMEM include path for distributed support (centralized in env.py)
nvshmem_include = _env.NVSHMEM_INCLUDE_DIR
if nvshmem_include:
options.append(f"-I{nvshmem_include}")

156 changes: 64 additions & 92 deletions tilelang/distributed/build_nvshmem.sh
@@ -1,103 +1,75 @@
#!/bin/bash
# Build NVSHMEM from source for TileLang device-side use. Run from repo root or set NVSHMEM_SRC.
# Usage: source build_nvshmem.sh [--arch 90] [--jobs N] [--force-download]
Review comment on lines +2 to +3 — ⚠️ Potential issue | 🟡 Minor

Inconsistent usage guidance: comment says `source`, help says `bash`.

Line 2 says "Run from repo root or set NVSHMEM_SRC" and line 3 says `source build_nvshmem.sh`, but the `--help` output on line 16 says `bash build_nvshmem.sh`. The Installation docs also use `bash`. Since the script uses `exit` (not `return`) and changes the cwd, it should not be sourced. Update the header comment to match:

-# Usage: source build_nvshmem.sh [--arch 90] [--jobs N] [--force-download]
+# Usage: bash build_nvshmem.sh [--arch 90] [--jobs N] [--force-download]

# Override at runtime: NVSHMEM_SRC, NVSHMEM_* (see below), CMAKE, NVSHMEM_VERSION.

if [ -z "${NVSHMEM_SRC}" ]; then
export NVSHMEM_SRC="$(realpath ../../3rdparty/nvshmem_src)"
echo "NVSHMEM_SRC not set, defaulting to ${NVSHMEM_SRC}"
else
NVSHMEM_SRC="$(realpath ${NVSHMEM_SRC})"
echo "Using NVSHMEM_SRC=${NVSHMEM_SRC}"
fi

if [ -d "${NVSHMEM_SRC}" ]; then
if [ "$(ls -A ${NVSHMEM_SRC})" ]; then
echo "NVSHMEM_SRC directory (${NVSHMEM_SRC}) is not empty, cleaning it..."
rm -rf "${NVSHMEM_SRC}/"*
rm -rf "${NVSHMEM_SRC}/".* 2>/dev/null || true
fi
else
mkdir -p "${NVSHMEM_SRC}"
fi

wget https://developer.nvidia.com/downloads/assets/secure/nvshmem/nvshmem_src_3.2.5-1.txz
tar zxvf nvshmem_src_3.2.5-1.txz
rm -rf nvshmem_src_3.2.5-1.txz

mkdir -p "${NVSHMEM_SRC}"
mv nvshmem_src/* "${NVSHMEM_SRC}/"
mv nvshmem_src/.* "${NVSHMEM_SRC}/" 2>/dev/null || true
rmdir nvshmem_src


export NVSHMEM_PATH="${NVSHMEM_SRC}"

VER="${NVSHMEM_VERSION:-3.2.5-1}"
TARBALL="nvshmem_src_${VER}.txz"
URL="${NVSHMEM_SOURCE_URL:-https://developer.nvidia.com/downloads/assets/secure/nvshmem/${TARBALL}}"
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
PROJECT_ROOT=$(realpath "${SCRIPT_DIR}")
echo "SCRIPT_DIR: ${SCRIPT_DIR}"
echo "PROJECT_ROOT: ${PROJECT_ROOT}"
echo "NVSHMEM will be installed to: ${NVSHMEM_SRC}"

FORCE_DL=""
ARCH=""
JOBS=""

# Iterate over the command-line arguments
while [[ $# -gt 0 ]]; do
key="$1"

case $key in
--arch)
# Process the arch argument
ARCH="$2"
shift # Skip the argument value
shift # Skip the argument key
;;
--jobs)
# Process the jobs argument
JOBS="$2"
shift # Skip the argument value
shift # Skip the argument key
;;
*)
# Unknown argument
echo "Unknown argument: $1"
shift # Skip the argument
;;
esac
case "$1" in
-h|--help) echo "Usage: bash build_nvshmem.sh [--arch ARCH] [--jobs N] [--force-download]"; exit 0 ;;
--arch) ARCH="$2"; shift 2 ;;
--jobs) JOBS="$2"; shift 2 ;;
--force-download) FORCE_DL=1; shift ;;
*) shift ;;
esac
done

if [[ -n "${ARCH}" ]]; then
export CMAKE_CUDA_ARCHITECTURES="${ARCH}"
CUDAARCH_ARGS="-DCMAKE_CUDA_ARCHITECTURES=${ARCH}"
fi

if [[ -z "${JOBS}" ]]; then
JOBS=$(nproc --ignore 2)
fi

export NVSHMEM_IBGDA_SUPPORT=0
export NVSHMEM_IBGDA_SUPPORT_GPUMEM_ONLY=0
export NVSHMEM_IBDEVX_SUPPORT=0
export NVSHMEM_IBRC_SUPPORT=1
export NVSHMEM_LIBFABRIC_SUPPORT=0
export NVSHMEM_MPI_SUPPORT=1
export NVSHMEM_USE_GDRCOPY=0
export NVSHMEM_TORCH_SUPPORT=1
export NVSHMEM_ENABLE_ALL_DEVICE_INLINING=1

pushd "${NVSHMEM_SRC}"
mkdir -p build
cd build
CMAKE=${CMAKE:-cmake}

if [ ! -f CMakeCache.txt ]; then
${CMAKE} .. \
-DCMAKE_EXPORT_COMPILE_COMMANDS=1 \
${CUDAARCH_ARGS} \
-DNVSHMEM_BUILD_TESTS=OFF \
-DNVSHMEM_BUILD_EXAMPLES=OFF \
-DNVSHMEM_BUILD_PACKAGES=OFF
export NVSHMEM_SRC="$(realpath "${NVSHMEM_SRC:-$SCRIPT_DIR/../../3rdparty/nvshmem_src}")"
export NVSHMEM_PATH="${NVSHMEM_SRC}"
JOBS="${JOBS:-$(nproc 2>/dev/null || echo 4)}"
ARCH="${ARCH:-$CMAKE_CUDA_ARCHITECTURES}"

if [[ -n "$FORCE_DL" ]] || [[ ! -f "${NVSHMEM_SRC}/CMakeLists.txt" ]]; then
if [[ -f "${NVSHMEM_SRC}/CMakeLists.txt" ]]; then
rm -rf "${NVSHMEM_SRC:?}/"*
rm -rf "${NVSHMEM_SRC}/".* 2>/dev/null || true
Review comment on lines 34 and 44 — ⚠️ Potential issue | 🟠 Major

`.*` glob matches `.` and `..` — use a safer pattern.

On lines 34 and 44, the `.*` glob expands to include `.` and `..`. While GNU rm/mv usually refuse to operate on these, POSIX doesn't guarantee it, and some environments may behave differently. Safer glob alternatives:

-  rm -rf "${NVSHMEM_SRC}/".* 2>/dev/null || true
+  rm -rf "${NVSHMEM_SRC}"/.[!.]* "${NVSHMEM_SRC}"/..?* 2>/dev/null || true
-  mv nvshmem_src/.* "${NVSHMEM_SRC}/" 2>/dev/null || true
+  mv nvshmem_src/.[!.]* nvshmem_src/..?* "${NVSHMEM_SRC}/" 2>/dev/null || true

The pattern `.[!.]*` matches single-dot-prefixed names and `..?*` matches double-dot-prefixed names with at least one more character, safely excluding `.` and `..`.

else
mkdir -p "${NVSHMEM_SRC}"
fi
cd "${SCRIPT_DIR}"
[[ -f "${TARBALL}" ]] || { wget -q --show-progress "${URL}" -O "${TARBALL}" || { echo "Download failed (login at developer.nvidia.com?)." >&2; exit 1; }; }
tar -zxf "${TARBALL}"
Review comment on line 40 — ⚠️ Potential issue | 🔴 Critical

Wrong tar decompression flag for `.txz` (xz) archive.

`tar -zxf` invokes gzip decompression, but `.txz` files are xz-compressed. This will fail at extraction time.

-  tar -zxf "${TARBALL}"
+  tar xf "${TARBALL}"

Using `tar xf` (no explicit compressor flag) lets modern GNU tar auto-detect the compression format, which works for both `.txz` and `.tar.gz`.

rm -f "${TARBALL}"
[[ -d nvshmem_src ]] || { echo "Missing nvshmem_src after extract." >&2; exit 1; }
mv nvshmem_src/* "${NVSHMEM_SRC}/"
mv nvshmem_src/.* "${NVSHMEM_SRC}/" 2>/dev/null || true
rmdir nvshmem_src
fi

make VERBOSE=1 -j"${JOBS}"
popd

echo "NVSHMEM installed successfully to ${NVSHMEM_SRC}"
export NVSHMEM_IBGDA_SUPPORT="${NVSHMEM_IBGDA_SUPPORT:-0}"
export NVSHMEM_IBGDA_SUPPORT_GPUMEM_ONLY="${NVSHMEM_IBGDA_SUPPORT_GPUMEM_ONLY:-0}"
export NVSHMEM_IBDEVX_SUPPORT="${NVSHMEM_IBDEVX_SUPPORT:-0}"
export NVSHMEM_IBRC_SUPPORT="${NVSHMEM_IBRC_SUPPORT:-1}"
export NVSHMEM_LIBFABRIC_SUPPORT="${NVSHMEM_LIBFABRIC_SUPPORT:-0}"
export NVSHMEM_MPI_SUPPORT="${NVSHMEM_MPI_SUPPORT:-1}"
export NVSHMEM_USE_GDRCOPY="${NVSHMEM_USE_GDRCOPY:-0}"
export NVSHMEM_TORCH_SUPPORT="${NVSHMEM_TORCH_SUPPORT:-1}"
export NVSHMEM_ENABLE_ALL_DEVICE_INLINING="${NVSHMEM_ENABLE_ALL_DEVICE_INLINING:-1}"
[[ -z "${ARCH}" ]] || export CMAKE_CUDA_ARCHITECTURES="${ARCH}"

cd "${NVSHMEM_SRC}"
mkdir -p build && cd build
CMAKE="${CMAKE:-cmake}"
[[ -f CMakeCache.txt ]] || ${CMAKE} .. \
-DCMAKE_EXPORT_COMPILE_COMMANDS=1 \
${ARCH:+-DCMAKE_CUDA_ARCHITECTURES=${ARCH}} \
-DNVSHMEM_BUILD_TESTS=OFF \
-DNVSHMEM_BUILD_EXAMPLES=OFF \
-DNVSHMEM_BUILD_PACKAGES=OFF

make -j"${JOBS}" VERBOSE=1

echo ""
echo "NVSHMEM built successfully at: ${NVSHMEM_SRC}"
echo ""
echo "To use NVSHMEM, add to your environment (e.g. in ~/.bashrc or before running examples):"
echo " export NVSHMEM_SRC=\"${NVSHMEM_SRC}\""
echo " export LD_LIBRARY_PATH=\"${NVSHMEM_SRC}/build/src/lib:\$LD_LIBRARY_PATH\""
echo ""
7 changes: 4 additions & 3 deletions tilelang/engine/lower.py
@@ -86,16 +86,17 @@ def tilelang_callback_cuda_compile(code, target, pass_config=None):
"-I" + tl_template_path,
"-I" + cutlass_path,
]
# Add NVSHMEM include path and library linking for distributed support
if env.USE_DISTRIBUTED and env.USE_NVSHMEM:

# Add NVSHMEM include path and library linking when NVSHMEM is enabled.
if env.USE_NVSHMEM:
if env.NVSHMEM_INCLUDE_DIR and env.NVSHMEM_LIB_PATH:
options.append("-I" + env.NVSHMEM_INCLUDE_DIR)
options.append("-L" + env.NVSHMEM_LIB_PATH)
options.append("-lnvshmem_device")
options.append("-rdc=true")
else:
raise ValueError(
"TILELANG_USE_DISTRIBUTED is enabled but NVSHMEM paths not found. Install nvidia-nvshmem-cu12 via pip or set NVSHMEM_SRC."
"TILELANG_USE_NVSHMEM is enabled but NVSHMEM paths not found. Install nvidia-nvshmem-cu12 via pip or set NVSHMEM_SRC."
)

# Merge extra device compiler flags from pass config, if provided
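The flag assembly in this hunk can be sketched as a standalone helper (illustrative only; the flag values and error text follow the diff above, but the function name and signature are not the actual tilelang API):

```python
def nvshmem_nvcc_options(include_dir, lib_path):
    """Build the extra nvcc flags used when NVSHMEM is enabled.

    Device-side NVSHMEM calls require linking libnvshmem_device and
    relocatable device code (-rdc=true); missing paths fail fast.
    """
    if not (include_dir and lib_path):
        raise ValueError(
            "TILELANG_USE_NVSHMEM is enabled but NVSHMEM paths not found. "
            "Install nvidia-nvshmem-cu12 via pip or set NVSHMEM_SRC.")
    return ["-I" + include_dir, "-L" + lib_path,
            "-lnvshmem_device", "-rdc=true"]
```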