|
| 1 | +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# |
| 3 | +# See LICENSE for license information. |
| 4 | +"""Shared functions for the comm_overlap tests""" |
| 5 | + |
| 6 | +import jax.numpy as jnp |
| 7 | +import numpy as np |
| 8 | + |
| 9 | + |
| 10 | +# Add this after your existing imports |
| 11 | +def dtype_tols(dtype, rtol=None, atol=None): |
| 12 | + """Expected numerical tolerance for a data type.""" |
| 13 | + # Return immediately if tolerances are fully specified |
| 14 | + if rtol is not None and atol is not None: |
| 15 | + return {"rtol": rtol, "atol": atol} |
| 16 | + |
| 17 | + # Default tolerances for common dtypes |
| 18 | + if dtype in [jnp.float32, "float32"]: |
| 19 | + return {"rtol": 1e-5, "atol": 1e-8} |
| 20 | + elif dtype in [jnp.float16, "float16"]: |
| 21 | + return {"rtol": 1e-3, "atol": 1e-6} |
| 22 | + elif dtype in [jnp.bfloat16, "bfloat16"]: |
| 23 | + return {"rtol": 1e-2, "atol": 1e-5} |
| 24 | + else: |
| 25 | + return {"rtol": 1e-5, "atol": 1e-8} |
| 26 | + |
| 27 | + |
| 28 | +def assert_allclose( |
| 29 | + actual, |
| 30 | + desired, |
| 31 | + rtol=None, |
| 32 | + atol=None, |
| 33 | + dtype=None, |
| 34 | + **kwargs, |
| 35 | +): |
| 36 | + """Check if two tensors are close.""" |
| 37 | + # Infer data type if needed |
| 38 | + if dtype is None: |
| 39 | + if isinstance(actual, float): |
| 40 | + dtype = "float32" |
| 41 | + else: |
| 42 | + dtype = actual.dtype |
| 43 | + |
| 44 | + # Determine tolerances |
| 45 | + tols = {} |
| 46 | + if rtol is None or atol is None: |
| 47 | + tols = dtype_tols(dtype) |
| 48 | + if rtol is not None: |
| 49 | + tols["rtol"] = rtol |
| 50 | + if atol is not None: |
| 51 | + tols["atol"] = atol |
| 52 | + |
| 53 | + # Cast tensors to fp32 |
| 54 | + if not isinstance(actual, float): |
| 55 | + actual = actual.astype(jnp.float32) |
| 56 | + if not isinstance(desired, float): |
| 57 | + desired = desired.astype(jnp.float32) |
| 58 | + |
| 59 | + # Check if tensors are close |
| 60 | + np.testing.assert_allclose(actual, desired, **tols, **kwargs) |
| 61 | + |
| 62 | + |
| 63 | +def assert_allclose_print_index(ref_output, gathered_output, rtol=1e-5, atol=1e-8): |
| 64 | + if not jnp.allclose(ref_output, gathered_output, rtol=rtol, atol=atol): |
| 65 | + diff = jnp.abs(ref_output - gathered_output) |
| 66 | + mask = diff > (atol + rtol * jnp.abs(gathered_output)) |
| 67 | + print(mask.astype(int)) |
| 68 | + print(jnp.where(mask, diff, 0)) |
| 69 | + |
| 70 | + |
| 71 | +# Shared constants for all tests |
| 72 | +DP_AXIS = "data" |
| 73 | +TPSP_AXIS = "tensor_sequence" |
| 74 | +PARAMS_KEY = "params" |
| 75 | + |
| 76 | +# Shared functions for distributed testing |
| 77 | +import argparse |
| 78 | +import jax |
| 79 | +from jax.experimental import mesh_utils |
| 80 | +from transformer_engine.jax.cpp_extensions.gemm import collective_gemm_bootstrap |
| 81 | + |
| 82 | +# Global flag to track if distributed has been initialized |
| 83 | +_distributed_initialized = False |
| 84 | + |
| 85 | + |
| 86 | +def _is_distributed_initialized(): |
| 87 | + """Check if JAX distributed has been initialized.""" |
| 88 | + return _distributed_initialized |
| 89 | + |
| 90 | + |
| 91 | +def _initialize_distributed(args): |
| 92 | + """Initialize JAX distributed with custom arguments.""" |
| 93 | + global _distributed_initialized |
| 94 | + |
| 95 | + # Check if already initialized |
| 96 | + if _distributed_initialized: |
| 97 | + return |
| 98 | + |
| 99 | + if args.coordinator_address is None or args.num_processes is None or args.process_id is None: |
| 100 | + raise ValueError( |
| 101 | + "All distributed initialization arguments are required: " |
| 102 | + "--coordinator-address, --num-processes, --process-id" |
| 103 | + ) |
| 104 | + if args.local_device_ids is None: |
| 105 | + assert ( |
| 106 | + args.num_devices_per_process is not None |
| 107 | + ), "Either local_device_ids or num_devices_per_process must be provided" |
| 108 | + # Calculate device range for this process |
| 109 | + # Single process single device: each process gets one unique device |
| 110 | + # Single process multiple devices: each process gets a unique range of devices |
| 111 | + start_device = args.process_id * args.num_devices_per_process |
| 112 | + device_range = range(start_device, start_device + args.num_devices_per_process) |
| 113 | + global_device_ids_for_this_process = ",".join(map(str, device_range)) |
| 114 | + else: |
| 115 | + # Use explicitly provided global device IDs |
| 116 | + global_device_ids_for_this_process = args.local_device_ids |
| 117 | + args.num_devices_per_process = len(args.local_device_ids.split(",")) |
| 118 | + |
| 119 | + assert args.num_devices_per_process == 1, "Only single process single GPU is supported!" |
| 120 | + |
| 121 | + print( |
| 122 | + f"Initializing JAX distributed with coordinator={args.coordinator_address}, " |
| 123 | + f"num_processes={args.num_processes}, process_id={args.process_id}" |
| 124 | + ) |
| 125 | + # Note: "local_device_ids" is a JAX term meaning "global CUDA devices managed by this process" |
| 126 | + jax.distributed.initialize( |
| 127 | + coordinator_address=args.coordinator_address, |
| 128 | + num_processes=args.num_processes, |
| 129 | + process_id=args.process_id, |
| 130 | + local_device_ids=global_device_ids_for_this_process, |
| 131 | + ) |
| 132 | + |
| 133 | + _distributed_initialized = True |
| 134 | + jax.clear_caches() |
| 135 | + jax.config.update( |
| 136 | + "jax_use_shardy_partitioner", False |
| 137 | + ) # CollectiveGEMM does not work with Shardy yet |
| 138 | + |
| 139 | + assert jax.local_device_count() == 1, ( |
| 140 | + f"[{args.process_id}|{args.num_devices_per_process}] Expected 1 GPU per process, found" |
| 141 | + f" {jax.local_device_count()}" |
| 142 | + ) |
| 143 | + |
| 144 | + devices_per_process = 1 |
| 145 | + num_total_devices = args.num_processes |
| 146 | + |
| 147 | + print( |
| 148 | + f"Initializing CGEMM communicator with num_total_devices={num_total_devices}," |
| 149 | + f" devices_per_process={devices_per_process}, process_id={args.process_id}" |
| 150 | + ) |
| 151 | + |
| 152 | + collective_gemm_bootstrap( |
| 153 | + num_total_devices=num_total_devices, |
| 154 | + num_devices_per_process=devices_per_process, |
| 155 | + process_id=args.process_id, |
| 156 | + tensor_parallel_size=args.tensor_parallel_size, |
| 157 | + ) |
| 158 | + |
| 159 | + |
| 160 | +def _get_dp_and_tp_sizes(args): |
| 161 | + num_gpu = args.num_processes * args.num_devices_per_process |
| 162 | + if args.tensor_parallel_size is None: |
| 163 | + num_gpu_dp = 2 if args.enable_data_parallel else 1 |
| 164 | + assert ( |
| 165 | + num_gpu > 1 and num_gpu % num_gpu_dp == 0 |
| 166 | + ), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs" |
| 167 | + num_gpu_tp = num_gpu // num_gpu_dp |
| 168 | + else: |
| 169 | + num_gpu_tp = args.tensor_parallel_size |
| 170 | + assert ( |
| 171 | + num_gpu > 1 and num_gpu % num_gpu_tp == 0 |
| 172 | + ), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs" |
| 173 | + num_gpu_dp = num_gpu // num_gpu_tp |
| 174 | + return num_gpu_dp, num_gpu_tp |
| 175 | + |
| 176 | + |
| 177 | +def _create_mesh(args): |
| 178 | + """Create mesh configuration with proper validation.""" |
| 179 | + num_gpu = args.num_processes * args.num_devices_per_process |
| 180 | + assert num_gpu == len(jax.devices()), "Number of GPUs must be equal to number of devices" |
| 181 | + num_gpu_dp, num_gpu_tp = _get_dp_and_tp_sizes(args) |
| 182 | + |
| 183 | + print(f"Using {num_gpu_dp}x{num_gpu_tp} mesh ({num_gpu_dp * num_gpu_tp} total GPUs)") |
| 184 | + |
| 185 | + device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp)) |
| 186 | + mesh = jax.sharding.Mesh(devices=device_mesh, axis_names=(DP_AXIS, TPSP_AXIS)) |
| 187 | + return mesh |
| 188 | + |
| 189 | + |
| 190 | +def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor parallelism"): |
| 191 | + """Create common argument parser for all collective GEMM tests.""" |
| 192 | + parser = argparse.ArgumentParser(description=description) |
| 193 | + |
| 194 | + # Distributed initialization arguments |
| 195 | + parser.add_argument( |
| 196 | + "--coordinator-address", |
| 197 | + type=str, |
| 198 | + default=None, |
| 199 | + help="Coordinator address for distributed initialization", |
| 200 | + ) |
| 201 | + parser.add_argument( |
| 202 | + "--num-processes", |
| 203 | + type=int, |
| 204 | + default=None, |
| 205 | + help="Number of processes for distributed initialization", |
| 206 | + ) |
| 207 | + parser.add_argument( |
| 208 | + "--process-id", type=int, default=None, help="Process ID for distributed initialization" |
| 209 | + ) |
| 210 | + parser.add_argument( |
| 211 | + "--local-device-ids", |
| 212 | + type=str, |
| 213 | + default=None, |
| 214 | + help="Local device IDs for distributed initialization (comma-separated)", |
| 215 | + ) |
| 216 | + parser.add_argument( |
| 217 | + "--num-devices-per-process", type=int, default=1, help="Number of devices per process" |
| 218 | + ) |
| 219 | + |
| 220 | + # Test configuration arguments |
| 221 | + parser.add_argument( |
| 222 | + "--tensor-parallel-size", type=int, default=None, help="Tensor parallel size" |
| 223 | + ) |
| 224 | + parser.add_argument("--batch-size", type=int, default=4, help="Batch size for testing") |
| 225 | + parser.add_argument("--seq-len", type=int, default=8192, help="Sequence length for testing") |
| 226 | + parser.add_argument("--hidden-in", type=int, default=4096, help="Input hidden dimension") |
| 227 | + parser.add_argument("--hidden-out", type=int, default=8192, help="Output hidden dimension") |
| 228 | + parser.add_argument( |
| 229 | + "--collective-type", |
| 230 | + type=str, |
| 231 | + default="all_gather", |
| 232 | + choices=["all_gather", "reduce_scatter"], |
| 233 | + help="Type of collective operation", |
| 234 | + ) |
| 235 | + parser.add_argument( |
| 236 | + "--fp8-recipe", type=str, default="DelayedScaling", help="FP8 recipe to use" |
| 237 | + ) |
| 238 | + parser.add_argument( |
| 239 | + "--enable-data-parallel", action="store_true", help="Enable data parallelism" |
| 240 | + ) |
| 241 | + parser.add_argument( |
| 242 | + "--enable-result-check", action="store_true", default=True, help="Enable result checking" |
| 243 | + ) |
| 244 | + |
| 245 | + return parser |
0 commit comments