Commit d75bf43

[JAX] CollectiveGemm (#2166)
* init cgemm + unit tests
* UB bootstrap with NCCL, no MPI dependency
* add NVLINK-P2P check + error message
* skip tests if no NVLINK available
* use std::vector to store ncclComm_t
* update misuse of TP warning

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
1 parent 4d14578 commit d75bf43

24 files changed: +2385 -97 lines

build_tools/jax.py

Lines changed: 1 addition & 0 deletions
@@ -87,4 +87,5 @@ def setup_jax_extension(
         sources=[str(path) for path in sources],
         include_dirs=[str(path) for path in include_dirs],
         extra_compile_args=cxx_flags,
+        libraries=["nccl"],
     )
Lines changed: 245 additions & 0 deletions
@@ -0,0 +1,245 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Shared functions for the comm_overlap tests"""

import jax.numpy as jnp
import numpy as np


def dtype_tols(dtype, rtol=None, atol=None):
    """Expected numerical tolerance for a data type."""
    # Return immediately if tolerances are fully specified
    if rtol is not None and atol is not None:
        return {"rtol": rtol, "atol": atol}

    # Default tolerances for common dtypes
    if dtype in [jnp.float32, "float32"]:
        return {"rtol": 1e-5, "atol": 1e-8}
    elif dtype in [jnp.float16, "float16"]:
        return {"rtol": 1e-3, "atol": 1e-6}
    elif dtype in [jnp.bfloat16, "bfloat16"]:
        return {"rtol": 1e-2, "atol": 1e-5}
    else:
        return {"rtol": 1e-5, "atol": 1e-8}


def assert_allclose(
    actual,
    desired,
    rtol=None,
    atol=None,
    dtype=None,
    **kwargs,
):
    """Check if two tensors are close."""
    # Infer data type if needed
    if dtype is None:
        if isinstance(actual, float):
            dtype = "float32"
        else:
            dtype = actual.dtype

    # Determine tolerances
    tols = {}
    if rtol is None or atol is None:
        tols = dtype_tols(dtype)
    if rtol is not None:
        tols["rtol"] = rtol
    if atol is not None:
        tols["atol"] = atol

    # Cast tensors to fp32
    if not isinstance(actual, float):
        actual = actual.astype(jnp.float32)
    if not isinstance(desired, float):
        desired = desired.astype(jnp.float32)

    # Check if tensors are close
    np.testing.assert_allclose(actual, desired, **tols, **kwargs)


def assert_allclose_print_index(ref_output, gathered_output, rtol=1e-5, atol=1e-8):
    if not jnp.allclose(ref_output, gathered_output, rtol=rtol, atol=atol):
        diff = jnp.abs(ref_output - gathered_output)
        mask = diff > (atol + rtol * jnp.abs(gathered_output))
        print(mask.astype(int))
        print(jnp.where(mask, diff, 0))


# Shared constants for all tests
DP_AXIS = "data"
TPSP_AXIS = "tensor_sequence"
PARAMS_KEY = "params"

# Shared functions for distributed testing
import argparse
import jax
from jax.experimental import mesh_utils
from transformer_engine.jax.cpp_extensions.gemm import collective_gemm_bootstrap

# Global flag to track if distributed has been initialized
_distributed_initialized = False


def _is_distributed_initialized():
    """Check if JAX distributed has been initialized."""
    return _distributed_initialized


def _initialize_distributed(args):
    """Initialize JAX distributed with custom arguments."""
    global _distributed_initialized

    # Check if already initialized
    if _distributed_initialized:
        return

    if args.coordinator_address is None or args.num_processes is None or args.process_id is None:
        raise ValueError(
            "All distributed initialization arguments are required: "
            "--coordinator-address, --num-processes, --process-id"
        )
    if args.local_device_ids is None:
        assert (
            args.num_devices_per_process is not None
        ), "Either local_device_ids or num_devices_per_process must be provided"
        # Calculate the device range for this process:
        # - one device per process: each process gets one unique device
        # - multiple devices per process: each process gets a unique range of devices
        start_device = args.process_id * args.num_devices_per_process
        device_range = range(start_device, start_device + args.num_devices_per_process)
        global_device_ids_for_this_process = ",".join(map(str, device_range))
    else:
        # Use explicitly provided global device IDs
        global_device_ids_for_this_process = args.local_device_ids
        args.num_devices_per_process = len(args.local_device_ids.split(","))

    assert args.num_devices_per_process == 1, "Only single process single GPU is supported!"

    print(
        f"Initializing JAX distributed with coordinator={args.coordinator_address}, "
        f"num_processes={args.num_processes}, process_id={args.process_id}"
    )
    # Note: "local_device_ids" is a JAX term meaning "global CUDA devices managed by this process"
    jax.distributed.initialize(
        coordinator_address=args.coordinator_address,
        num_processes=args.num_processes,
        process_id=args.process_id,
        local_device_ids=global_device_ids_for_this_process,
    )

    _distributed_initialized = True
    jax.clear_caches()
    jax.config.update(
        "jax_use_shardy_partitioner", False
    )  # CollectiveGEMM does not work with Shardy yet

    assert jax.local_device_count() == 1, (
        f"[{args.process_id}|{args.num_devices_per_process}] Expected 1 GPU per process, found"
        f" {jax.local_device_count()}"
    )

    devices_per_process = 1
    num_total_devices = args.num_processes

    print(
        f"Initializing CGEMM communicator with num_total_devices={num_total_devices},"
        f" devices_per_process={devices_per_process}, process_id={args.process_id}"
    )

    collective_gemm_bootstrap(
        num_total_devices=num_total_devices,
        num_devices_per_process=devices_per_process,
        process_id=args.process_id,
        tensor_parallel_size=args.tensor_parallel_size,
    )


def _get_dp_and_tp_sizes(args):
    num_gpu = args.num_processes * args.num_devices_per_process
    if args.tensor_parallel_size is None:
        num_gpu_dp = 2 if args.enable_data_parallel else 1
        assert (
            num_gpu > 1 and num_gpu % num_gpu_dp == 0
        ), "Number of GPUs must be greater than 1 and divisible by the number of data-parallel GPUs"
        num_gpu_tp = num_gpu // num_gpu_dp
    else:
        num_gpu_tp = args.tensor_parallel_size
        assert (
            num_gpu > 1 and num_gpu % num_gpu_tp == 0
        ), "Number of GPUs must be greater than 1 and divisible by the number of tensor-parallel GPUs"
        num_gpu_dp = num_gpu // num_gpu_tp
    return num_gpu_dp, num_gpu_tp


def _create_mesh(args):
    """Create mesh configuration with proper validation."""
    num_gpu = args.num_processes * args.num_devices_per_process
    assert num_gpu == len(jax.devices()), "Number of GPUs must be equal to number of devices"
    num_gpu_dp, num_gpu_tp = _get_dp_and_tp_sizes(args)

    print(f"Using {num_gpu_dp}x{num_gpu_tp} mesh ({num_gpu_dp * num_gpu_tp} total GPUs)")

    device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
    mesh = jax.sharding.Mesh(devices=device_mesh, axis_names=(DP_AXIS, TPSP_AXIS))
    return mesh


def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor parallelism"):
    """Create common argument parser for all collective GEMM tests."""
    parser = argparse.ArgumentParser(description=description)

    # Distributed initialization arguments
    parser.add_argument(
        "--coordinator-address",
        type=str,
        default=None,
        help="Coordinator address for distributed initialization",
    )
    parser.add_argument(
        "--num-processes",
        type=int,
        default=None,
        help="Number of processes for distributed initialization",
    )
    parser.add_argument(
        "--process-id", type=int, default=None, help="Process ID for distributed initialization"
    )
    parser.add_argument(
        "--local-device-ids",
        type=str,
        default=None,
        help="Local device IDs for distributed initialization (comma-separated)",
    )
    parser.add_argument(
        "--num-devices-per-process", type=int, default=1, help="Number of devices per process"
    )

    # Test configuration arguments
    parser.add_argument(
        "--tensor-parallel-size", type=int, default=None, help="Tensor parallel size"
    )
    parser.add_argument("--batch-size", type=int, default=4, help="Batch size for testing")
    parser.add_argument("--seq-len", type=int, default=8192, help="Sequence length for testing")
    parser.add_argument("--hidden-in", type=int, default=4096, help="Input hidden dimension")
    parser.add_argument("--hidden-out", type=int, default=8192, help="Output hidden dimension")
    parser.add_argument(
        "--collective-type",
        type=str,
        default="all_gather",
        choices=["all_gather", "reduce_scatter"],
        help="Type of collective operation",
    )
    parser.add_argument(
        "--fp8-recipe", type=str, default="DelayedScaling", help="FP8 recipe to use"
    )
    parser.add_argument(
        "--enable-data-parallel", action="store_true", help="Enable data parallelism"
    )
    parser.add_argument(
        "--enable-result-check", action="store_true", default=True, help="Enable result checking"
    )

    return parser
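
For orientation, here is a minimal, hypothetical driver showing how a per-process test script could stitch these helpers together. The module name cgemm_common is a placeholder for the helper file above (its path is not shown in this view), and the final check is only a toy use of the dtype-aware tolerances, not the actual collective GEMM test.

# Hypothetical sketch, not part of the commit; `cgemm_common` stands in for the
# shared-helpers module shown above.
import jax.numpy as jnp

from cgemm_common import (
    DP_AXIS,
    TPSP_AXIS,
    _create_mesh,
    _initialize_distributed,
    assert_allclose,
    cgemm_parser,
)


def main():
    args = cgemm_parser("Sketch: bring-up for a collective GEMM test").parse_args()
    _initialize_distributed(args)  # jax.distributed.initialize + collective_gemm_bootstrap
    mesh = _create_mesh(args)      # 2-D mesh over (DP_AXIS, TPSP_AXIS)
    print(f"[{args.process_id}] mesh axes={mesh.axis_names} shape={mesh.devices.shape}")

    # Toy use of the tolerance helpers: bf16 comparisons get loose rtol/atol.
    x = jnp.full((args.hidden_in,), 1.0, dtype=jnp.bfloat16)
    assert_allclose(3.0 * x, jnp.full_like(x, 3.0), dtype=jnp.bfloat16)


if __name__ == "__main__":
    main()

In a real test, a jitted matmul sharded over the mesh would replace the toy check, with its reference output compared via assert_allclose.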
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""config for collective_gemm tests"""
import pytest


def pytest_addoption(parser):
    """Pytest hook for collective_gemm tests"""
    parser.addoption("--coordinator-address", action="store", default="localhost:12345")
    parser.addoption("--num-processes", action="store", default=1)
    parser.addoption("--process-id", action="store", default=0)
    parser.addoption("--local-device-ids", action="store", default=None)


@pytest.fixture(autouse=True)
def distributed_args(request):
    """Fixture for querying distributed initialization arguments"""
    if request.cls:
        request.cls.coordinator_address = request.config.getoption("--coordinator-address")
        request.cls.num_processes = int(request.config.getoption("--num-processes"))
        request.cls.process_id = int(request.config.getoption("--process-id"))
        request.cls.local_device_ids = request.config.getoption("--local-device-ids")
        request.cls.num_devices_per_process = (
            1
            if request.cls.local_device_ids is None
            else len(request.cls.local_device_ids.split(","))
        )
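
Because the fixture is autouse and writes onto request.cls, a class-based test picks these values up as plain attributes. A minimal, hypothetical example (the class name and assertions are illustrative only):

# Hypothetical test-class sketch, not part of the commit: the autouse
# `distributed_args` fixture above injects these attributes before each test.
from types import SimpleNamespace


class TestDistributedArgsSketch:

    def test_process_id_is_valid(self):
        # Bundle the injected attributes into an args-like namespace, the same
        # shape of object that _initialize_distributed() expects.
        args = SimpleNamespace(
            coordinator_address=self.coordinator_address,
            num_processes=self.num_processes,
            process_id=self.process_id,
            local_device_ids=self.local_device_ids,
            num_devices_per_process=self.num_devices_per_process,
        )
        assert args.num_processes >= 1
        assert 0 <= args.process_id < args.num_processes
        assert args.num_devices_per_process >= 1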
Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)}

# Check if NVLINK is supported before running tests
echo "*** Checking NVLINK support***"
NVLINK_OUTPUT=$(nvidia-smi nvlink --status 2>&1)
NVLINK_EXIT_CODE=$?

# Check if the command failed OR the output indicates no NVLINK
if [ $NVLINK_EXIT_CODE -ne 0 ] || [[ "$NVLINK_OUTPUT" == *"not supported"* ]] || [[ "$NVLINK_OUTPUT" == *"No devices"* ]] || [ -z "$NVLINK_OUTPUT" ]; then
    echo "NVLINK is not supported on this platform"
    echo "Collective GEMM tests require NVLINK connectivity"
    echo "SKIPPING all tests"
    exit 0
else
    echo "NVLINK support detected"
fi

# Define the test files to run
TEST_FILES=(
    "test_gemm.py"
    "test_dense_grad.py"
    "test_layernorm_mlp_grad.py"
)

echo
echo "*** Executing tests in examples/jax/collective_gemm/ ***"

HAS_FAILURE=0  # Global failure flag
PIDS=()        # Array to store all process PIDs

# Cleanup function to kill all processes
cleanup() {
    for pid in "${PIDS[@]}"; do
        if kill -0 "$pid" 2>/dev/null; then
            echo "Killing process $pid"
            kill -TERM "$pid" 2>/dev/null || true
        fi
    done
    # Wait a bit and force kill if needed
    sleep 2
    for pid in "${PIDS[@]}"; do
        if kill -0 "$pid" 2>/dev/null; then
            echo "Force killing process $pid"
            kill -KILL "$pid" 2>/dev/null || true
        fi
    done
}

# Set up signal handlers to clean up on exit
trap cleanup EXIT INT TERM

# Run each test file across all GPUs
for TEST_FILE in "${TEST_FILES[@]}"; do
    echo
    echo "=== Starting test file: $TEST_FILE ..."

    # Clear PIDs array for this test file
    PIDS=()

    for i in $(seq 0 $(($NUM_GPUS - 1))); do
        # Define output file for logs
        LOG_FILE="${TEST_FILE}_gpu_${i}.log"

        if [ $i -eq 0 ]; then
            # For process 0: show live output AND save to log file using tee
            echo "=== Live output from process 0 ==="
            pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
                -vs "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \
                --num-processes=$NUM_GPUS \
                --process-id=$i 2>&1 | tee "$LOG_FILE" &
            PID=$!
            PIDS+=($PID)
        else
            # For other processes: redirect to log files only
            pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
                -vs "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \
                --num-processes=$NUM_GPUS \
                --process-id=$i > "$LOG_FILE" 2>&1 &
            PID=$!
            PIDS+=($PID)
        fi
    done

    # Wait for all processes to finish
    wait

    # Check and print the log content from process 0 (it has a log file thanks to tee)
    if grep -q "SKIPPED" "${TEST_FILE}_gpu_0.log"; then
        echo "... $TEST_FILE SKIPPED"
    elif grep -q "FAILED" "${TEST_FILE}_gpu_0.log"; then
        echo "... $TEST_FILE FAILED"
        HAS_FAILURE=1
    else
        echo "... $TEST_FILE PASSED"
    fi

    # Remove the log files after processing them
    wait
    rm ${TEST_FILE}_gpu_*.log
done

wait

# Final cleanup (trap will also call cleanup on exit)
cleanup

exit $HAS_FAILURE
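
For reference, a hypothetical Python equivalent of a single loop iteration above (kept in Python to match the other examples): it launches one pytest process per GPU with its own --process-id, streams rank 0 to the console, and logs the other ranks to files. TE_PATH and a 2-GPU machine are assumptions here.

# Hypothetical launcher sketch, not part of the commit. Mirrors one iteration of
# the loop above for TEST_FILE="test_gemm.py" on an assumed 2-GPU machine.
import os
import subprocess

NUM_GPUS = 2
TE_PATH = os.environ["TE_PATH"]  # assumed to be set, as in the script above
TEST_FILE = "test_gemm.py"

procs = []
for rank in range(NUM_GPUS):
    cmd = [
        "pytest", "-s", "-c", f"{TE_PATH}/tests/jax/pytest.ini",
        "-vs", f"{TE_PATH}/examples/jax/collective_gemm/{TEST_FILE}",
        f"--num-processes={NUM_GPUS}",
        f"--process-id={rank}",
    ]
    if rank == 0:
        procs.append(subprocess.Popen(cmd))  # live output on the console
    else:
        log = open(f"{TEST_FILE}_gpu_{rank}.log", "w")
        procs.append(subprocess.Popen(cmd, stdout=log, stderr=subprocess.STDOUT))

exit_codes = [p.wait() for p in procs]
print("FAILED" if any(exit_codes) else "PASSED")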
