diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index 93c0c2d874..244d86b46c 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -325,6 +325,13 @@ def make_trace(func): if getattr(compile_data.fn, "use_ddp", False): bw_extrace = sort_waits(bw_extrace) + # Importing here to avoid cyclical dependencies in future. + from thunder.executors.transformer_engineex import transformer_engine_ex, _rearrange_transformer_engine_linear + + if transformer_engine_ex in compile_data.executors_list: + # NOTE: `_rearrange_transformer_engine_linear` mutates `fw_extrace`. + _rearrange_transformer_engine_linear(fw_extrace, bw_extrace) + fw_extrace = del_last_used(fw_extrace) fw_traces.append(fw_extrace) diff --git a/thunder/executors/transformer_engineex.py b/thunder/executors/transformer_engineex.py index ba53c3e940..6dc4b2be6a 100644 --- a/thunder/executors/transformer_engineex.py +++ b/thunder/executors/transformer_engineex.py @@ -430,3 +430,91 @@ def _linear_grad(a: TensorProxy, w: TensorProxy, b: TensorProxy) -> TensorProxy: execution_transform=_linear_transform, grad_transform=_linear_grad, ) + + +def _rearrange_transformer_engine_linear(fw_extrace, bw_extrace): + """ + Rearrange the TransformerEngine linear symbols `te_linear_*` in forward trace + so that we match the constraint that first FP8 module being called + in forward is the last FP8 module whose gradient is computed in backward pass. + + Implementation: + From the backward trace, we find the `ctx_name` of the last `te_functional_linear_backward`. + Then we iterate the forward trace and find the `te_linear` which produces the `ctx_name` + found above. We move this `te_linear` above the first `te_linear` currently in the fwd_trace. + + ..note:: + We could have also done it such that we find the `ctx_name` for first `te_linear` in forward + and re-order the backward pass. + However, on a real model llama2.c example, I noticed that FusionExecutor can create pseudo dependency. + See the example below. + + Details: + TransformerEngine takes care of syncing FP8 meta-data + in distributed setting (if world_size > 1). The way this is handled + is by marking the first FP8 module in forward pass. In the backward pass + of that module (last in FP8 module in backward), it collects all the FP8 state, + this state is concatenated, then synced acorss the processes and then split back + into individual state again. + Implementation of the above is in `prepare_forward` and `_prepare_backward` in + `transformer_engine/pytorch/module/base.py` + This means that in thunder, we can't reorder the first `te_linear` or the last backward. + However, FusionExecutors may reorder them. + This function takes care of rearranging such that adhere to this requirement. + Implementation of `prepare_forward`: https://github.com/NVIDIA/TransformerEngine/blob/2d0ab27f/transformer_engine/pytorch/module/base.py#L501 + Implementation of `_prepare_backward : https://github.com/NVIDIA/TransformerEngine/blob/2d0ab27f/transformer_engine/pytorch/module/base.py#L67 + + Example: + + Forward Trace Snippet: + [t22, t26] = nvFusion0(t16, t25) + (t77, ctx_te_2) = te_linear_2(t26, layers_0_attention_wv_weight, None) + (t53, ctx_te_1) = te_linear_1(t26, layers_0_attention_wk_weight, None) + (t29, ctx_te_0) = te_linear_0(t26, layers_0_attention_wq_weight, None) + + Backward Trace Snippet (without the `del` for brevity): + NOTE: t6822 is part of nvFusion35 which also produces input for te_functional_linear_backward below it. + (t6821, t6822, _) = te_functional_linear_backward(t6819, (i443, i444, i445), (i446, i447), None, ctx_te_2) + NOTE: `nvFusion35` just does `true_divide(t6822, 2)` and returns it for synchronization. + but it also picks up a few operations which process the input for other `te_functional_linear_backward` below. + [t6823, t6857, t6900] = nvFusion35(f468, f476, i293, i294, i295, i296, i297, i432, i433, i434, i435, i436, t36, t38, t6810, t6812, t6822) + t6901 = torch.reshape(t6900, (i186, i187, i188, i189)) # t6901: "cuda:0 f32[128, 256, 6, 48]" + t6902 = torch.reshape(t6901, (i178, i179, i180)) # t6902: "cuda:0 f32[128, 256, 288]" + t6858 = torch.reshape(t6857, (i325, i326, i327, i328)) # t6858: "cuda:0 f32[128, 256, 6, 48]" + t6859 = torch.reshape(t6858, (i317, i318, i319)) # t6859: "cuda:0 f32[128, 256, 288]" + (t6904, t6905, _) = te_functional_linear_backward(t6902, (i165, i166, i167), (i168, i169), None, ctx_te_0) + (t6861, t6862, _) = te_functional_linear_backward(t6859, (i304, i305, i306), (i307, i308), None, ctx_te_1) + """ + # Get the ctx name for the last `te_functional_linear_backward`. + bwd_bsym_ctx = None + for _, bsym in enumerate(reversed(bw_extrace.bound_symbols)): + if bsym.sym.id == te_functional_linear_backward.id: + bwd_bsym_ctx = bsym.args[-1].name + break + + first_sym_idx = None + detected_first_sym_idx = None + # Find the first `te_linear` in forward trace + # and the position of `te_linear` which has the last `ctx_name` + # in backward. + for idx, bsym in enumerate(fw_extrace.bound_symbols): + # Forward symbols are generated on the fly so we don't + # have access here. + # Instead we check for the executor field. + if bsym.sym.executor == transformer_engine_ex: + # Sanity check. + assert "te_linear" in bsym.sym.name + if first_sym_idx is None: + first_sym_idx = idx + if bsym.output[-1].name == bwd_bsym_ctx: + detected_first_sym_idx = idx + break + + # If the first `te_linear` is not same as that one that should be + # we move it to be the first one. + if detected_first_sym_idx != first_sym_idx: + # Move the symbol to be the first `te_linear`. + fwd_bsyms = fw_extrace.bound_symbols + sym_to_swap = fwd_bsyms[detected_first_sym_idx] + del fwd_bsyms[detected_first_sym_idx] + fwd_bsyms.insert(first_sym_idx, sym_to_swap) diff --git a/thunder/tests/distributed/test_ddp.py b/thunder/tests/distributed/test_ddp.py index 03bedfd334..49592d78f1 100644 --- a/thunder/tests/distributed/test_ddp.py +++ b/thunder/tests/distributed/test_ddp.py @@ -28,6 +28,27 @@ from thunder.tests.framework import TorchExecutor, nvFuserExecutor from thunder.tests.framework import instantiate +# It is important to set this flag so that TE doesn't use +# `torch.compile` to fuse a few operations. This is because +# `torch.compile` creates a new process and that leads to +# the error : daemonic processes are not allowed to have children +# when running the tests. +# With the setting below, we use `torch.jit` for this test suite +# See: https://github.com/NVIDIA/TransformerEngine/blob/a38b291b0d1b04847e8ab1df8550df642a03a27d/transformer_engine/pytorch/jit.py#L11-L19 +os.environ["NVTE_TORCH_COMPILE"] = "0" +from thunder.executors.transformer_engineex import transformer_engine_ex, TE_AVAILABLE + +is_fp8_supported: bool = False +# This will be correctly updated below when TE Engine is installed +# and if the current environment doesn't support FP8. +fp8_support_reason: str = "" +if TE_AVAILABLE: + from transformer_engine.pytorch import fp8_autocast + from transformer_engine.pytorch import Linear as TELinear + from transformer_engine.pytorch.fp8 import check_fp8_support + + is_fp8_supported, fp8_support_reason = check_fp8_support() + try: import expecttest # noqa: F401 import hypothesis # noqa: F401 @@ -1213,6 +1234,187 @@ def finalize_pg(pg): return None +def _test_ddp_transformer_engine(input_data): + # Test Description: We run a dummy training loop for a simple `Linear(Relu(Linear(x)))` + # model with thunder (using TE executor) and with PyTorch eager + TE + # and verify that the weights have converged to same value and + # fp8 meta state is same after `n_iter`. + init_method, world_size, rank, executor, device, dtype, _unused_bucketing_strategy = input_data + devicetype = devices.device_from_string(device).devicetype + _unused_dtype = ltorch.to_torch_dtype(dtype) + init_per_process_distributed(init_method, devicetype, world_size, rank) + + torch.cuda.set_device(rank) + + dim = 256 + n_iter = 10 + + class ThunderModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Linear(dim, dim, bias=False) + self.fc2 = torch.nn.Linear(dim, dim, bias=False) + + def forward(self, x): + return self.fc2(torch.nn.functional.relu(self.fc1(x))) + + # Weights + fc1_weight = torch.randn(dim, dim, requires_grad=True).cuda() + fc2_weight = torch.randn(dim, dim, requires_grad=True).cuda() + + # Inputs (different input on different rank). + if rank == 0: + x = torch.arange(dim * dim, dtype=torch.float).view(dim, dim).cuda() + if rank == 1: + x = torch.randn(dim, dim).cuda() * 100 + + thunder_model = ThunderModel().cuda() + thunder_model.fc1.weight.data = fc1_weight.clone() + thunder_model.fc2.weight.data = fc2_weight.clone() + + jit_model = thunder.jit( + thunder.distributed.ddp(thunder_model), + executors=[ + transformer_engine_ex, + ] + + executor.executors_list(), + ) + + optim = torch.optim.SGD(thunder_model.parameters()) + + for _ in range(n_iter): + with fp8_autocast(): + o = jit_model(x).sum() + o.backward() + optim.step() + optim.zero_grad() + + class TEModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = TELinear(dim, dim, bias=False) + self.fc2 = TELinear(dim, dim, bias=False) + + def forward(self, x): + return self.fc2(torch.nn.functional.relu(self.fc1(x))) + + te_model = TEModel().cuda() + te_model.fc1.weight.data = fc1_weight.clone() + te_model.fc2.weight.data = fc2_weight.clone() + + ddp_model = DDP(te_model) + + optim = torch.optim.SGD(te_model.parameters()) + + for _ in range(n_iter): + with fp8_autocast(): + o = ddp_model(x).sum() + + o.backward() + optim.step() + optim.zero_grad() + + thunder_to_te_layer_map = {"te_linear_0": te_model.fc1, "te_linear_1": te_model.fc2} + + fwd_traces = thunder.last_traces(jit_model) + + def is_same_across_ranks(t): + t_clone = t.clone() + torch.distributed.all_reduce(t_clone, op=torch.distributed.ReduceOp.AVG) + assert_close(t, t_clone) + + # Compare the state of the two models. + comparison_exceptions = [] + for bound_symbol in fwd_traces[-1].bound_symbols: + if "te_linear" in bound_symbol.sym.name: + thunder_fp8_meta = bound_symbol._call_ctx[bound_symbol.sym.name].func.fp8_meta + te_fp8_meta = thunder_to_te_layer_map[bound_symbol.sym.name].fp8_meta + try: + # fwd tensor history + assert_close(thunder_fp8_meta["scaling_fwd"].scale, te_fp8_meta["scaling_fwd"].scale) + assert_close(thunder_fp8_meta["scaling_fwd"].scale_inv, te_fp8_meta["scaling_fwd"].scale_inv) + assert_close(thunder_fp8_meta["scaling_fwd"].amax_history, te_fp8_meta["scaling_fwd"].amax_history) + # bwd tensor history + assert_close(thunder_fp8_meta["scaling_bwd"].scale, te_fp8_meta["scaling_bwd"].scale) + assert_close(thunder_fp8_meta["scaling_bwd"].scale_inv, te_fp8_meta["scaling_bwd"].scale_inv) + assert_close(thunder_fp8_meta["scaling_bwd"].amax_history, te_fp8_meta["scaling_bwd"].amax_history) + + # This has to be on all ranks so that the computation is not blocked + is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].scale) + is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].scale_inv) + is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].amax_history) + is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].scale) + is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].scale_inv) + is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].amax_history) + except Exception as e: + # Return exceptions only for rank==0 + if rank == 0: + comparison_exceptions.append(e) + + # Compare weights after `n_iters` + try: + assert_close(thunder_model.fc1.weight, te_model.fc1.weight) + assert_close(thunder_model.fc2.weight, te_model.fc2.weight) + except Exception as e: + # Return exceptions only for rank==0 + if rank == 0: + comparison_exceptions.append(e) + + return comparison_exceptions + + +def _test_ddp_transformer_engine_llama_sanity(input_data): + # Test Description: We run a dummy training loop for a Transformer Model + # We run a few iterations to see that TransformerEngine doesn't throw internal assertion + # due to reordering of forward and backward operators. + # (This test will fail without `_rearrange_transformer_engine_linear` in `torch_autograd.py`) + # For more details, see docstring for `_rearrange_transformer_engine_linear` in transformer_engine_ex.py. + from thunder.tests.llama2_model import Transformer, ModelArgs + + init_method, world_size, rank, executor, device, dtype, _unused_bucketing_strategy = input_data + devicetype = devices.device_from_string(device).devicetype + _unused_dtype = ltorch.to_torch_dtype(dtype) + init_per_process_distributed(init_method, devicetype, world_size, rank) + + torch.cuda.set_device(rank) + # data + batch_size = 2 + max_seq_len = 32 + vocab_size = 32 + + model_args = dict( + dim=32, + n_layers=1, + n_heads=2, + n_kv_heads=2, + vocab_size=vocab_size, + multiple_of=32, + max_seq_len=max_seq_len, + dropout=0.0, + ) + gptconf = ModelArgs(**model_args) + model = Transformer(gptconf) + model.to(device) + x = torch.randint(0, vocab_size, (batch_size, max_seq_len), dtype=torch.int64, device=device) + y = torch.randint(0, vocab_size, (batch_size, max_seq_len), dtype=torch.int64, device=device) + jit_model = thunder.jit( + thunder.distributed.ddp(model), executors=(transformer_engine_ex,) + thunder.get_default_executors() + ) + + sanity_exceptions = [] + try: + for _ in range(5): + with fp8_autocast(): + out = jit_model(x, y).sum() + out.backward() + except Exception as e: + sanity_exceptions.append(e) + + if rank == 0: + return sanity_exceptions + return None + + # NOTE This is just a stub, see the NOTE for ddp_wrapper @instantiate( dtypes=(thunder.float32,), @@ -1246,5 +1448,35 @@ def test_native_fsdp(executor, devices, dtype, bucket_size_in_mb): pass +@instantiate( + dtypes=(thunder.float32,), + num_devices=2, + devicetypes=(devices.DeviceType.CUDA,), + executors=(TorchExecutor,), + decorators=( + pytest.mark.skipif(not TE_AVAILABLE, reason="TransformerEngine is not installed."), + pytest.mark.skipif(not is_fp8_supported, reason=fp8_support_reason), + ), +) +@ddp_wrapper("test_ddp_transformer_engine", _test_ddp_transformer_engine) +def test_ddp_transformer_engine(executor, devices, dtype): + pass + + +@instantiate( + dtypes=(thunder.float32,), + num_devices=2, + devicetypes=(devices.DeviceType.CUDA,), + executors=(TorchExecutor,), + decorators=( + pytest.mark.skipif(not TE_AVAILABLE, reason="TransformerEngine is not installed."), + pytest.mark.skipif(not is_fp8_supported, reason=fp8_support_reason), + ), +) +@ddp_wrapper("test_ddp_transformer_engine_llama_sanity", _test_ddp_transformer_engine_llama_sanity) +def test_ddp_transformer_engine_llama_sanity(executor, devices, dtype): + pass + + if __name__ == "__main__": common_utils.run_tests()