TE executor: DDP support (PR2408) (#80)
Whenever the world size is greater than 1, the first TE module in the forward pass (i.e. the last to execute in the backward pass) takes care of syncing the FP8 meta-data state across all processes before the next iteration begins.

During forward, TE sets up the FP8 meta-data reduction for each TE module if `world_size > 1`:
https://github.com/NVIDIA/TransformerEngine/blob/a38b291b0d1b04847e8ab1df8550df642a03a27d/transformer_engine/pytorch/module/base.py#L552-L564

During the backward of the first TE module in the forward pass (the last in the backward pass), the FP8 meta-data is actually synced, and by default this is synchronous/blocking. It essentially does a `torch.cat` of all FP8 state, a reduction, and a `torch.split` of the reduced state back into the individual buffers (a small illustrative sketch of this pattern appears after the benchmark patch below).
https://github.com/NVIDIA/TransformerEngine/blob/8255f87f3ee8076db21777795ce15b6ddf8754c0/transformer_engine/pytorch/module/base.py#L98-L100

This means there are constraints on re-ordering the `te_linear` calls; see `NOTE: TransformerEngine Distributed Ordering Constraint` in `torch_autograd.py`. Thanks @crcrpar for pointing me towards this.

This PR adds a DDP test comparing `thunder + TE executor` against `PyTorch eager + TE`.

**Benchmark numbers on a real model**:

Running `examples/llama2.c/train.py` on 2 Ada RTX 6000 GPUs.

Cmd: `torchrun --nproc-per-node=2 train.py`

Without TE
```
95 | loss 8.2919 | lr 4.750000e-05 | 357.21ms | mfu 5.69%
96 | loss 8.3111 | lr 4.800000e-05 | 357.91ms | mfu 5.69%
97 | loss 8.2762 | lr 4.850000e-05 | 356.87ms | mfu 5.69%
98 | loss 8.2394 | lr 4.900000e-05 | 355.05ms | mfu 5.69%
99 | loss 8.2340 | lr 4.950000e-05 | 355.43ms | mfu 5.69%
100 | loss 8.1790 | lr 5.000000e-05 | 355.44ms | mfu 5.69%
```

With TE
```
95 | loss 8.3030 | lr 4.750000e-05 | 334.94ms | mfu 6.05%
96 | loss 8.3212 | lr 4.800000e-05 | 335.32ms | mfu 6.05%
97 | loss 8.2859 | lr 4.850000e-05 | 334.99ms | mfu 6.05%
98 | loss 8.2492 | lr 4.900000e-05 | 334.98ms | mfu 6.05%
99 | loss 8.2434 | lr 4.950000e-05 | 335.02ms | mfu 6.05%
100 | loss 8.1892 | lr 5.000000e-05 | 334.47ms | mfu 6.05%
```

<details>
<summary>Patch for Benchmark</summary>

```patch
diff --git a/examples/llama2.c/train.py b/examples/llama2.c/train.py
index 18290df0..dcb52561 100644
--- a/examples/llama2.c/train.py
+++ b/examples/llama2.c/train.py
@@ -28,6 +28,7 @@ from model import Transformer, ModelArgs
 from torch.distributed import destroy_process_group, init_process_group
 from torch.nn.parallel import DistributedDataParallel as DDP
 import torch.nn.functional as F
+import transformer_engine.pytorch as te
 
 from tinystories import Task
 from export import model_export
@@ -60,7 +61,7 @@ dropout = 0.0
 # adamw optimizer
 gradient_accumulation_steps = 4 # used to simulate larger batch sizes
 learning_rate = 5e-4 # max learning rate
-max_iters = 100000 # total number of training iterations
+max_iters = 100 # total number of training iterations
 weight_decay = 1e-1
 beta1 = 0.9
 beta2 = 0.95
@@ -212,7 +213,8 @@ if compile == "thunder":
     import thunder
     from thunder.executors.sdpaex import sdpa_ex
 
-    executors = [sdpa_ex, thunder.nvfuser_executor, thunder.pytorch_executor]
+    from thunder.executors.transformer_engineex import transformer_engine_ex
+    executors = [transformer_engine_ex, sdpa_ex, thunder.nvfuser_executor, thunder.pytorch_executor]
 
     eval_model = thunder.compile(eval_model.eval(), disable_torch_autograd_support=True, executors_list=executors)
     train_model = thunder.compile(train_model.train(), executors_list=executors)
@@ -316,7 +318,7 @@ while True:
             # I really dislike that this bloats the code and forces us to repeat code
             # looking at the source of that context manager, it just toggles this variable
             train_model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1
-        with ctx:
+        with ctx, te.fp8_autocast():
             logits = train_model(X, Y)
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1), ignore_index=-1)
             loss = loss / gradient_accumulation_steps
```

</details>
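For reference, here is a minimal illustrative sketch of the blocking cat/reduce/split pattern described above. This is not TE's actual implementation; the function name, buffer layout, and the choice of a max-reduction are assumptions made purely for illustration.

```python
# Illustrative sketch (NOT TE's actual code): flatten every module's FP8
# meta-data buffers into one tensor, all-reduce it across ranks, then split
# the reduced result back into the individual per-module buffers.
import torch
import torch.distributed as dist


def sync_fp8_meta_across_ranks(amax_buffers: list[torch.Tensor]) -> None:
    # `amax_buffers` stands in for the per-module FP8 meta-data tensors.
    sizes = [t.numel() for t in amax_buffers]
    flat = torch.cat([t.view(-1) for t in amax_buffers])

    # Blocking collective: every rank waits here until the reduction finishes.
    # (Max-reduction is an assumption for this sketch.)
    dist.all_reduce(flat, op=dist.ReduceOp.MAX)

    # Scatter the reduced state back into the original buffers.
    for buf, chunk in zip(amax_buffers, torch.split(flat, sizes)):
        buf.copy_(chunk.view_as(buf))
```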
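And for the other side of the new DDP test, roughly what the `PyTorch eager + TE` baseline looks like: torch DDP around TE modules, with `te.fp8_autocast` enabling FP8 and `fp8_group` selecting the process group for the FP8 meta-data reduction. This is a hedged sketch, not the test's actual code; the toy two-layer model, the dimensions, and passing `fp8_group` explicitly are illustrative choices.

```python
# Launch with: torchrun --nproc-per-node=2 this_file.py
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import transformer_engine.pytorch as te

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Toy model; FP8 GEMMs want dimensions that are multiples of 16.
model = torch.nn.Sequential(
    te.Linear(256, 1024),
    te.Linear(1024, 256),
).cuda()
model = DDP(model, device_ids=[local_rank])
opt = torch.optim.SGD(model.parameters(), lr=1e-3)

for _ in range(5):
    x = torch.randn(32, 256, device="cuda")
    # fp8_group tells TE which process group to reduce the FP8 meta-data over.
    with te.fp8_autocast(enabled=True, fp8_group=dist.group.WORLD):
        out = model(x)
    loss = out.float().sum()
    loss.backward()  # the last TE module's backward performs the blocking FP8 meta-data sync
    opt.step()
    opt.zero_grad(set_to_none=True)

dist.destroy_process_group()
```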
1 parent 888b463 · commit aef1f4c · 3 changed files: 349 additions, 0 deletions