diff --git a/thunder/tests/distributed/test_ddp.py b/thunder/tests/distributed/test_ddp.py
index 05c6a9ed7a..dcd9a1229c 100644
--- a/thunder/tests/distributed/test_ddp.py
+++ b/thunder/tests/distributed/test_ddp.py
@@ -16,6 +16,8 @@ import torch.utils.data as tudata
 from torch.distributed import distributed_c10d as c10d
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.distributed.fsdp import FullyShardedDataParallel
+from torch.distributed.fsdp.wrap import always_wrap_policy
 from torch.testing import assert_close, make_tensor
 
 import thunder
@@ -1040,12 +1042,12 @@ def __call__(self, test_stub):
         init_method = f"{FILE_SCHEMA}{file_name}"
 
         @wraps(test_stub)
-        def test_fn(executor, devices, dtype, bucket_size_in_mb=0):
+        def test_fn(executor, devices, dtype, **kwargs):
             world_size = len(devices)
             input_data = []
 
             for rank in range(world_size):
-                process_data = (init_method, world_size, rank, executor, devices[rank], dtype, bucket_size_in_mb)
+                process_data = (init_method, world_size, rank, executor, devices[rank], dtype, kwargs)
                 input_data.append(process_data)
 
             ctx = mp.get_context("spawn")
@@ -1080,7 +1082,8 @@ def error_callback(ex):
 # NOTE This assumes that one process will have rank=0 -- could generalize that to root
 # TODO Test training, this test just currently tests forward
 def _test_native_ddp_helper(input_data):
-    init_method, world_size, rank, executor, device, dtype, bucket_size_in_mb = input_data
+    init_method, world_size, rank, executor, device, dtype, kwargs = input_data
+    bucket_size_in_mb = kwargs.get("bucket_size_in_mb", 0)
 
     num_samples = 2
     tensor_shape = (2, 2)
@@ -1172,7 +1175,8 @@ def _test_native_ddp_helper(input_data):
 
 
 def _test_native_fsdp_helper(input_data):
-    init_method, world_size, rank, executor, device, dtype, bucketing_strategy = input_data
+    init_method, world_size, rank, executor, device, dtype, kwargs = input_data
+    bucketing_strategy = kwargs["fsdp_bucketing_strategy"]
 
     num_samples = 2
     tensor_shape = (2, 2)
@@ -1260,7 +1264,7 @@ def _test_ddp_transformer_engine(input_data):
     # model with thunder (using TE executor) and with PyTorch eager + TE
    # and verify that the weights have converged to same value and
     # fp8 meta state is same after `n_iter`.
-    init_method, world_size, rank, executor, device, dtype, _unused_bucketing_strategy = input_data
+    init_method, world_size, rank, executor, device, dtype, _unused_kwargs = input_data
     devicetype = devices.device_from_string(device).devicetype
     _unused_dtype = ltorch.to_torch_dtype(dtype)
     init_per_process_distributed(init_method, devicetype, world_size, rank)
@@ -1362,7 +1366,18 @@ def is_same_across_ranks(t):
                 # This has to be on all ranks so that the computation is not blocked
                 is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].scale)
                 is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].scale_inv)
-                is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].amax_history)
+                # NOTE: TE forward tensor meta-data sync
+                # Syncing of FP8 meta-data happens in two steps in the forward pass.
+                # 1. When we enter fp8_autocast(), all the forward FP8 meta-data
+                #    in the global buffer is synced.
+                #    See: https://github.com/NVIDIA/TransformerEngine/blob/6a9edc38bf9b941b7d369af5103fa8fe0b121d61/transformer_engine/pytorch/fp8.py#L409-L412
+                # 2. After that, in the forward pass of the module, `prepare_forward`
+                #    reads the synced meta-data from the global buffer.
+                #    See: https://github.com/NVIDIA/TransformerEngine/blob/6a9edc38bf9b941b7d369af5103fa8fe0b121d61/transformer_engine/pytorch/module/base.py#L539-L545
+                # However, by the end of this forward pass we have seen new inputs and outputs, and their amax values
+                # are recorded in the 0th row of `amax_history` (which will only be synced in the next forward pass).
+                # So, here we check that every row except row 0 is the same across ranks.
+                is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].amax_history[1:])
                 is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].scale)
                 is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].scale_inv)
                 is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].amax_history)
@@ -1371,16 +1386,16 @@ def is_same_across_ranks(t):
                 if rank == 0:
                     comparison_exceptions.append(e)
 
-        # Compare weights after `n_iters`
-        try:
-            assert_close(thunder_model.fc1.weight, te_model.fc1.weight)
-            assert_close(thunder_model.fc2.weight, te_model.fc2.weight)
-        except Exception as e:
-            # Return exceptions only for rank==0
-            if rank == 0:
-                comparison_exceptions.append(e)
+    # Compare weights after `n_iters`
+    try:
+        assert_close(thunder_model.fc1.weight, te_model.fc1.weight)
+        assert_close(thunder_model.fc2.weight, te_model.fc2.weight)
+    except Exception as e:
+        # Return exceptions only for rank==0
+        if rank == 0:
+            comparison_exceptions.append(e)
 
-        return comparison_exceptions
+    return comparison_exceptions
 
 
 def _test_ddp_transformer_engine_llama_sanity(input_data):
@@ -1391,7 +1406,7 @@ def _test_ddp_transformer_engine_llama_sanity(input_data):
     # For more details, see docstring for `_rearrange_transformer_engine_linear` in transformer_engine_ex.py.
     from thunder.tests.llama2_model import Transformer, ModelArgs
 
-    init_method, world_size, rank, executor, device, dtype, _unused_bucketing_strategy = input_data
+    init_method, world_size, rank, executor, device, dtype, _unused_kwargs = input_data
     devicetype = devices.device_from_string(device).devicetype
     _unused_dtype = ltorch.to_torch_dtype(dtype)
     init_per_process_distributed(init_method, devicetype, world_size, rank)
@@ -1453,6 +1468,141 @@ def _test_ddp_transformer_engine_llama_sanity(input_data):
     return None
 
 
+def _test_fsdp_transformer_engine(input_data):
+    # Test Description: We run a dummy training loop for a simple `Linear(Relu(Linear(x)))`
+    # model with thunder (using TE executor) and with PyTorch eager + TE
+    # and verify that the weights have converged to same value and
+    # fp8 meta state is same after `n_iter`.
+    init_method, world_size, rank, executor, device, _unused_dtype, kwargs = input_data
+    thunder_fsdp_strategy = kwargs["thunder_fsdp_strategy"]
+    devicetype = devices.device_from_string(device).devicetype
+
+    # Setting LOCAL_RANK is necessary for thunder.distributed.fsdp
+    with unittest.mock.patch.dict(os.environ, {"LOCAL_RANK": str(rank)}):
+        init_per_process_distributed(init_method, devicetype, world_size, rank)
+        torch.cuda.set_device(rank)
+
+        dim = 256
+        n_iter = 10
+
+        class ThunderModel(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.fc1 = torch.nn.Linear(dim, dim, bias=False)
+                self.fc2 = torch.nn.Linear(dim, dim, bias=False)
+
+            def forward(self, x):
+                return self.fc2(torch.nn.functional.relu(self.fc1(x)))
+
+        # Weights
+        fc1_weight = torch.randn(dim, dim, requires_grad=True, device="cuda")
+        fc2_weight = torch.randn(dim, dim, requires_grad=True, device="cuda")
+
+        # Inputs (different inputs on different ranks).
+        if rank == 0:
+            x = torch.arange(dim * dim, dtype=torch.float, device="cuda").view(dim, dim)
+        if rank == 1:
+            x = torch.randn(dim, dim, device="cuda") * 100
+
+        with torch.device("cuda"):
+            thunder_model = ThunderModel()
+            thunder_model.fc1.weight.data = fc1_weight.clone()
+            thunder_model.fc2.weight.data = fc2_weight.clone()
+
+        jit_model = thunder.jit(
+            thunder.distributed.fsdp(thunder_model, sharding_strategy=thunder_fsdp_strategy),
+            executors=[
+                transformer_engine_ex,
+            ]
+            + executor.executors_list(),
+        )
+
+        optim = torch.optim.SGD(thunder_model.parameters())
+
+        for _ in range(n_iter):
+            o = jit_model(x).sum()
+            o.backward()
+            optim.step()
+            optim.zero_grad()
+
+        class TEModel(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.fc1 = TELinear(dim, dim, bias=False)
+                self.fc2 = TELinear(dim, dim, bias=False)
+
+            def forward(self, x):
+                return self.fc2(torch.nn.functional.relu(self.fc1(x)))
+
+        with torch.device("cuda"):
+            te_model = TEModel()
+            te_model.fc1.weight.data = fc1_weight.clone()
+            te_model.fc2.weight.data = fc2_weight.clone()
+
+        fsdp_model = FullyShardedDataParallel(te_model, auto_wrap_policy=always_wrap_policy)
+
+        optim = torch.optim.SGD(te_model.parameters())
+
+        for _ in range(n_iter):
+            with fp8_autocast():
+                o = fsdp_model(x).sum()
+
+            o.backward()
+            optim.step()
+            optim.zero_grad()
+
+        thunder_to_te_layer_map = {"te_linear_0": te_model.fc1, "te_linear_1": te_model.fc2}
+
+        fwd_traces = thunder.last_traces(jit_model)
+
+        def is_same_across_ranks(t):
+            t_clone = t.clone()
+            torch.distributed.all_reduce(t_clone, op=torch.distributed.ReduceOp.AVG)
+            assert_close(t, t_clone)
+
+        # Compare the state of the two models.
+        comparison_exceptions = []
+        for bound_symbol in fwd_traces[-1].bound_symbols:
+            if "te_linear" in bound_symbol.sym.name:
+                thunder_fp8_meta = bound_symbol._call_ctx[bound_symbol.sym.name].func.fp8_meta
+                te_fp8_meta = thunder_to_te_layer_map[bound_symbol.sym.name].fp8_meta
+                try:
+                    # fwd tensor history
+                    assert_close(thunder_fp8_meta["scaling_fwd"].scale, te_fp8_meta["scaling_fwd"].scale)
+                    assert_close(thunder_fp8_meta["scaling_fwd"].scale_inv, te_fp8_meta["scaling_fwd"].scale_inv)
+                    assert_close(thunder_fp8_meta["scaling_fwd"].amax_history, te_fp8_meta["scaling_fwd"].amax_history)
+                    # bwd tensor history
+                    assert_close(thunder_fp8_meta["scaling_bwd"].scale, te_fp8_meta["scaling_bwd"].scale)
+                    assert_close(thunder_fp8_meta["scaling_bwd"].scale_inv, te_fp8_meta["scaling_bwd"].scale_inv)
+                    assert_close(thunder_fp8_meta["scaling_bwd"].amax_history, te_fp8_meta["scaling_bwd"].amax_history)
+
+                    # This has to be on all ranks so that the computation is not blocked
+                    is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].scale)
+                    is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].scale_inv)
+                    # See NOTE: TE forward tensor meta-data sync
+                    is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].amax_history[1:])
+                    is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].scale)
+                    is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].scale_inv)
+                    is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].amax_history)
+                except Exception as e:
+                    # Return exceptions only for rank==0
+                    if rank == 0:
+                        comparison_exceptions.append(e)
+
+        # Compare weights after `n_iters`
+        shard_size = int(dim / world_size)
+        fsdp_te_params = tuple(te_model.parameters())
+        try:
+            assert_close(thunder_model.fc1.weight, fsdp_te_params[0].view(shard_size, dim))
+            assert_close(thunder_model.fc2.weight, fsdp_te_params[1].view(shard_size, dim))
+        except Exception as e:
+            # Return exceptions only for rank==0
+            if rank == 0:
+                comparison_exceptions.append(e)
+
+        return comparison_exceptions
+
+
 # NOTE This is just a stub, see the NOTE for ddp_wrapper
 @instantiate(
     dtypes=(thunder.float32,),
@@ -1472,7 +1622,7 @@ def test_native_ddp(executor, devices, dtype, bucket_size_in_mb):
     devicetypes=(devices.DeviceType.CUDA,),
     decorators=(
         pytest.mark.parametrize(
-            "bucket_size_in_mb",
+            "fsdp_bucketing_strategy",
             (
                 FSDPBucketingStrategy.NONE,
                 FSDPBucketingStrategy.LAYER,
@@ -1482,7 +1632,7 @@ def test_native_ddp(executor, devices, dtype, bucket_size_in_mb):
         ),
     ),
 )
 @ddp_wrapper("test_native_fsdp", _test_native_fsdp_helper)
-def test_native_fsdp(executor, devices, dtype, bucket_size_in_mb):
+def test_native_fsdp(executor, devices, dtype, fsdp_bucketing_strategy):
     pass
@@ -1531,5 +1681,32 @@ def test_ddp_transformer_engine_llama_sanity(executor, devices, dtype):
     pass
 
 
+@instantiate(
+    dtypes=(thunder.float32,),
+    num_devices=2,
+    devicetypes=(devices.DeviceType.CUDA,),
+    executors=(TorchExecutor,),
+    decorators=(
+        # NOTE: ddp_wrapper
+        pytest.mark.parametrize(
+            "thunder_fsdp_strategy",
+            (
+                FSDPType.ZERO2,
+                FSDPType.ZERO3,
+            ),
+        ),
+        pytest.mark.skipif(not TE_AVAILABLE, reason="TransformerEngine is not installed."),
+        pytest.mark.skipif(not is_fp8_supported, reason=fp8_support_reason),
+        # See NOTE: Setting `NVTE_TORCH_COMPILE`
+        # NOTE: We don't pass `clear=True` to `unittest.mock.patch.dict` as that may clear paths
+        # from the environment, leading to incorrect dependencies being picked up in the spawned process.
+        unittest.mock.patch.dict(os.environ, {"NVTE_TORCH_COMPILE": "0"}),
+    ),
+)
+@ddp_wrapper("test_fsdp_transformer_engine", _test_fsdp_transformer_engine)
+def test_fsdp_transformer_engine(executor, devices, dtype, thunder_fsdp_strategy):
+    pass
+
+
 if __name__ == "__main__":
     common_utils.run_tests()
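For reference, below is a minimal, self-contained sketch (not part of the patch) of the cross-rank equality check that the tests rely on in `is_same_across_ranks`: all-reduce a clone of the tensor and compare it against the original, which only passes when the tensor is identical on every rank. The patch uses `ReduceOp.AVG`, which requires the NCCL backend and GPUs; the sketch below assumes a 2-process CPU run with the gloo backend, so it uses SUM divided by `world_size` instead. The helper names (`check_equal_across_ranks`, `_worker`) are illustrative, not part of the test file.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.testing import assert_close


def check_equal_across_ranks(t: torch.Tensor, world_size: int) -> None:
    # If `t` is identical on every rank, summing across ranks and dividing by
    # world_size reproduces `t` exactly; otherwise assert_close fails.
    t_sum = t.clone()
    dist.all_reduce(t_sum, op=dist.ReduceOp.SUM)
    assert_close(t, t_sum / world_size)


def _worker(rank: int, world_size: int) -> None:
    # Minimal single-host process-group setup for the sketch.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    same_on_all_ranks = torch.ones(4)  # identical on every rank -> check passes
    check_equal_across_ranks(same_on_all_ranks, world_size)

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)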