TE executor: DDP support (PR2408) #80

Merged: 10 commits into Lightning-AI:main on Apr 5, 2024

Conversation

kshitij12345 (Collaborator)

Whenever the world size is greater than 1, the first TE module in the forward pass (i.e. the last to execute its backward) takes care of syncing the FP8 meta-data state across all processes before the next iteration begins.

During the forward pass, if world_size > 1, TE sets up FP8 meta-data reduction for each TE module.
https://github.com/NVIDIA/TransformerEngine/blob/a38b291b0d1b04847e8ab1df8550df642a03a27d/transformer_engine/pytorch/module/base.py#L552-L564

During the backward pass of the first TE module from the forward pass (i.e. the last one in the backward pass), TE actually syncs the FP8 meta-data, and this is synchronous/blocking by default. (It essentially does a torch.cat of all FP8 state, a reduction across processes, and a torch.split of the reduced state back into the individual buffers.)
https://github.com/NVIDIA/TransformerEngine/blob/8255f87f3ee8076db21777795ce15b6ddf8754c0/transformer_engine/pytorch/module/base.py#L98-L100
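
For intuition, here is a minimal conceptual sketch of that flatten/reduce/split pattern. This is not TE's actual code; the helper name, the MAX all-reduce, and the buffer layout are illustrative assumptions.

```python
# Conceptual sketch only: not TE's implementation.
import torch
import torch.distributed as dist

def sync_amax_across_ranks(amax_buffers: list[torch.Tensor], group=None) -> None:
    """Flatten all per-module FP8 amax buffers, reduce them once, and scatter them back."""
    # One torch.cat so a single collective covers every module's FP8 state.
    flat = torch.cat([buf.view(-1) for buf in amax_buffers])
    # Reduce so every rank ends up with the same global statistics.
    dist.all_reduce(flat, op=dist.ReduceOp.MAX, group=group)
    # torch.split the reduced tensor and copy the pieces back into the individual buffers.
    sizes = [buf.numel() for buf in amax_buffers]
    for buf, chunk in zip(amax_buffers, torch.split(flat, sizes)):
        buf.copy_(chunk.view_as(buf))
```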

This means there are constraints on re-ordering the te_linear calls; see NOTE: TransformerEngine Distributed Ordering Constraint in torch_autograd.py.

Thanks @crcrpar for pointing me towards this.

This PR adds a DDP test comparing the thunder + TE executor against PyTorch eager + TE.
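
For context, here is an illustrative sketch of the PyTorch eager + TE + DDP pattern that such a comparison runs against; the layer sizes, script structure, and launch details are assumptions, not the actual test code.

```python
# Illustrative baseline sketch (assumptions, not the test added in this PR):
# PyTorch eager + TransformerEngine modules under torch DDP.
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import transformer_engine.pytorch as te

def main():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)

    # FP8 GEMMs want dimensions that are multiples of 16.
    model = DDP(te.Linear(256, 256).cuda(), device_ids=[local_rank])
    x = torch.randn(32, 256, device="cuda")

    # fp8_autocast enables FP8 execution; fp8_group can be passed to control
    # which ranks participate in the amax reduction described above.
    with te.fp8_autocast(enabled=True):
        out = model(x)
    out.sum().backward()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```

The thunder side of the comparison compiles an ordinary torch model with the TE executor enabled, e.g. `executors_list=[transformer_engine_ex, ...]` as in the patch below, launched with something like `torchrun --nproc-per-node=2 <script>.py`.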

Benchmark numbers on a real model:

Running examples/llama2.c/train.py on 2 Ada RTX6000 GPUs.
Cmd: torchrun --nproc-per-node=2 train.py

Without TE

95 | loss 8.2919 | lr 4.750000e-05 | 357.21ms | mfu 5.69%
96 | loss 8.3111 | lr 4.800000e-05 | 357.91ms | mfu 5.69%
97 | loss 8.2762 | lr 4.850000e-05 | 356.87ms | mfu 5.69%
98 | loss 8.2394 | lr 4.900000e-05 | 355.05ms | mfu 5.69%
99 | loss 8.2340 | lr 4.950000e-05 | 355.43ms | mfu 5.69%
100 | loss 8.1790 | lr 5.000000e-05 | 355.44ms | mfu 5.69%

With TE

95 | loss 8.3030 | lr 4.750000e-05 | 334.94ms | mfu 6.05%
96 | loss 8.3212 | lr 4.800000e-05 | 335.32ms | mfu 6.05%
97 | loss 8.2859 | lr 4.850000e-05 | 334.99ms | mfu 6.05%
98 | loss 8.2492 | lr 4.900000e-05 | 334.98ms | mfu 6.05%
99 | loss 8.2434 | lr 4.950000e-05 | 335.02ms | mfu 6.05%
100 | loss 8.1892 | lr 5.000000e-05 | 334.47ms | mfu 6.05%
Patch for Benchmark
diff --git a/examples/llama2.c/train.py b/examples/llama2.c/train.py
index 18290df0..dcb52561 100644
--- a/examples/llama2.c/train.py
+++ b/examples/llama2.c/train.py
@@ -28,6 +28,7 @@ from model import Transformer, ModelArgs
 from torch.distributed import destroy_process_group, init_process_group
 from torch.nn.parallel import DistributedDataParallel as DDP
 import torch.nn.functional as F
+import transformer_engine.pytorch as te
 
 from tinystories import Task
 from export import model_export
@@ -60,7 +61,7 @@ dropout = 0.0
 # adamw optimizer
 gradient_accumulation_steps = 4  # used to simulate larger batch sizes
 learning_rate = 5e-4  # max learning rate
-max_iters = 100000  # total number of training iterations
+max_iters = 100  # total number of training iterations
 weight_decay = 1e-1
 beta1 = 0.9
 beta2 = 0.95
@@ -212,7 +213,8 @@ if compile == "thunder":
 
     import thunder
     from thunder.executors.sdpaex import sdpa_ex
-    executors = [sdpa_ex, thunder.nvfuser_executor, thunder.pytorch_executor]
+    from thunder.executors.transformer_engineex import transformer_engine_ex
+    executors = [transformer_engine_ex, sdpa_ex, thunder.nvfuser_executor, thunder.pytorch_executor]
 
     eval_model = thunder.compile(eval_model.eval(), disable_torch_autograd_support=True, executors_list=executors)
     train_model = thunder.compile(train_model.train(), executors_list=executors)
@@ -316,7 +318,7 @@ while True:
             # I really dislike that this bloats the code and forces us to repeat code
             # looking at the source of that context manager, it just toggles this variable
             train_model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1
-        with ctx:
+        with ctx, te.fp8_autocast():
             logits = train_model(X, Y)
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1), ignore_index=-1)
             loss = loss / gradient_accumulation_steps

@kshitij12345 changed the title from "TE executor: DDP support" to "TE executor: DDP support (PR2408)" on Mar 26, 2024
@kshitij12345 (Collaborator, Author)

Blocked on #81

@kshitij12345 marked this pull request as ready for review on March 26, 2024 19:42
@IvanYashchuk (Collaborator) left a comment

Perfect!

@IvanYashchuk enabled auto-merge (squash) on April 2, 2024 10:15
@kshitij12345 (Collaborator, Author)

Ping @t-vi for merging.

@tfogal (Collaborator) left a comment

The minor nit about environment variables aside, I am concerned about this approach long-term. Wouldn't we also have to do something special for the FSDP transform, then? And couldn't any other arbitrary pass reorder things in such a way that it breaks this implicit requirement?
Perhaps I am misunderstanding something; apologies if so!

If TE is doing communication implicitly and that communication can cause issues, I think we need to have real dialogue with the TE team on this. I wouldn't be surprised if the answer is that we need to take over TE's communication here. Two systems scheduling comms is going to lead us towards a deadlock-prone system.

@kshitij12345 (Collaborator, Author)

> Wouldn't we also have to do something special for the FSDP transform, then? And couldn't any other arbitrary pass reorder things in such a way that it breaks this implicit requirement?

For now, #74 makes sure that we don't have additional transforms after this. But I agree that this approach is quite fragile and ideally we should have the ability to inform TE when to sync.

> If TE is doing communication implicitly and that communication can cause issues, I think we need to have real dialogue with the TE team on this. I wouldn't be surprised if the answer is that we need to take over TE's communication here. Two systems scheduling comms is going to lead us towards a deadlock-prone system.

We have started communicating with the TE team with a similar request. I will loop you into that thread so you can elaborate there in case we missed anything.

Thanks for having a look

@carmocca (Contributor) left a comment

LGTM and thank you!

@kshitij12345 (Collaborator, Author)

@carmocca I think this is ready to merge. The CI failures look unrelated.

@carmocca (Contributor) commented Apr 5, 2024

Once we get past the flakes, automerge will do its thing 🙌

@IvanYashchuk merged commit aef1f4c into Lightning-AI:main on Apr 5, 2024
39 checks passed
@github-actions bot deleted the te_ddp branch on July 6, 2024 00:38