TE executor: DDP support (PR2408) (#80)
Whenever the world size is greater than 1, the first TE module in the forward pass (i.e. the last to execute in the backward pass) takes care of syncing the FP8 meta-data state across all processes before the next iteration begins.

During the forward pass, TE sets up FP8 meta-data reduction for each TE module when `world_size > 1`.
https://github.com/NVIDIA/TransformerEngine/blob/a38b291b0d1b04847e8ab1df8550df642a03a27d/transformer_engine/pytorch/module/base.py#L552-L564

During the backward pass of the first TE module in forward (i.e. the last in backward), TE actually syncs the FP8 meta-data, and this is synchronous/blocking by default. (It essentially does a `torch.cat` of all FP8 state, a reduction, and a `torch.split` of the reduced state back into the individual buffers.)
https://github.com/NVIDIA/TransformerEngine/blob/8255f87f3ee8076db21777795ce15b6ddf8754c0/transformer_engine/pytorch/module/base.py#L98-L100
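
To make this concrete, here is a minimal sketch of the concat → all-reduce → split pattern described above, written with plain `torch.distributed`. This is illustrative only, not TE's actual implementation: the helper name is made up and the choice of reduction op is an assumption.

```python
# Illustrative sketch only (assumed helper, not TE's API): the FP8 meta-data
# sync pattern described above, expressed with plain torch.distributed.
import torch
import torch.distributed as dist


def sync_fp8_meta(amax_buffers: list[torch.Tensor]) -> None:
    # Concatenate all per-module buffers so a single collective is issued
    # instead of one per TE module.
    flat = torch.cat([buf.view(-1) for buf in amax_buffers])
    # Blocking all-reduce across data-parallel ranks (MAX is assumed here for amax),
    # matching the "synchronous/blocking by default" behaviour noted above.
    dist.all_reduce(flat, op=dist.ReduceOp.MAX)
    # Split the reduced tensor back into the original per-module buffers.
    for buf, chunk in zip(amax_buffers, flat.split([b.numel() for b in amax_buffers])):
        buf.copy_(chunk.view_as(buf))
```

In TE this work hangs off the first FP8 module seen in forward (and therefore the last backward), which is why its position in the traces matters.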

This means there are constraints on re-ordering the `te_linear` symbols; see `NOTE: TransformerEngine Distributed Ordering Constraint` in `torch_autograd.py`.

Thanks @crcrpar for pointing me towards this.

This PR adds a DDP test comparing `thunder + TE executor` against `PyTorch eager + TE`.

**Benchmark numbers on a real model**:

Running `examples/llama2.c/train.py` on 2 Ada RTX 6000 GPUs
Cmd: `torchrun --nproc-per-node=2 train.py`

Without TE
```
95 | loss 8.2919 | lr 4.750000e-05 | 357.21ms | mfu 5.69%
96 | loss 8.3111 | lr 4.800000e-05 | 357.91ms | mfu 5.69%
97 | loss 8.2762 | lr 4.850000e-05 | 356.87ms | mfu 5.69%
98 | loss 8.2394 | lr 4.900000e-05 | 355.05ms | mfu 5.69%
99 | loss 8.2340 | lr 4.950000e-05 | 355.43ms | mfu 5.69%
100 | loss 8.1790 | lr 5.000000e-05 | 355.44ms | mfu 5.69%
```

With TE
```
95 | loss 8.3030 | lr 4.750000e-05 | 334.94ms | mfu 6.05%
96 | loss 8.3212 | lr 4.800000e-05 | 335.32ms | mfu 6.05%
97 | loss 8.2859 | lr 4.850000e-05 | 334.99ms | mfu 6.05%
98 | loss 8.2492 | lr 4.900000e-05 | 334.98ms | mfu 6.05%
99 | loss 8.2434 | lr 4.950000e-05 | 335.02ms | mfu 6.05%
100 | loss 8.1892 | lr 5.000000e-05 | 334.47ms | mfu 6.05%
```

<details>

<summary> Patch for Benchmark </summary>

```patch
diff --git a/examples/llama2.c/train.py b/examples/llama2.c/train.py
index 18290df0..dcb52561 100644
--- a/examples/llama2.c/train.py
+++ b/examples/llama2.c/train.py
@@ -28,6 +28,7 @@ from model import Transformer, ModelArgs
 from torch.distributed import destroy_process_group, init_process_group
 from torch.nn.parallel import DistributedDataParallel as DDP
 import torch.nn.functional as F
+import transformer_engine.pytorch as te
 
 from tinystories import Task
 from export import model_export
@@ -60,7 +61,7 @@ dropout = 0.0
 # adamw optimizer
 gradient_accumulation_steps = 4  # used to simulate larger batch sizes
 learning_rate = 5e-4  # max learning rate
-max_iters = 100000  # total number of training iterations
+max_iters = 100  # total number of training iterations
 weight_decay = 1e-1
 beta1 = 0.9
 beta2 = 0.95
@@ -212,7 +213,8 @@ if compile == "thunder":
 
     import thunder
     from thunder.executors.sdpaex import sdpa_ex
-    executors = [sdpa_ex, thunder.nvfuser_executor, thunder.pytorch_executor]
+    from thunder.executors.transformer_engineex import transformer_engine_ex
+    executors = [transformer_engine_ex, sdpa_ex, thunder.nvfuser_executor, thunder.pytorch_executor]
 
     eval_model = thunder.compile(eval_model.eval(), disable_torch_autograd_support=True, executors_list=executors)
     train_model = thunder.compile(train_model.train(), executors_list=executors)
@@ -316,7 +318,7 @@ while True:
             # I really dislike that this bloats the code and forces us to repeat code
             # looking at the source of that context manager, it just toggles this variable
             train_model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1
-        with ctx:
+        with ctx, te.fp8_autocast():
             logits = train_model(X, Y)
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1), ignore_index=-1)
             loss = loss / gradient_accumulation_steps
```
</details>
kshitij12345 authored Apr 5, 2024
1 parent 888b463 commit aef1f4c
Showing 3 changed files with 349 additions and 0 deletions.
7 changes: 7 additions & 0 deletions thunder/executors/torch_autograd.py
```diff
@@ -325,6 +325,13 @@ def make_trace(func):
     if getattr(compile_data.fn, "use_ddp", False):
         bw_extrace = sort_waits(bw_extrace)

+    # Importing here to avoid cyclical dependencies in future.
+    from thunder.executors.transformer_engineex import transformer_engine_ex, _rearrange_transformer_engine_linear
+
+    if transformer_engine_ex in compile_data.executors_list:
+        # NOTE: `_rearrange_transformer_engine_linear` mutates `fw_extrace`.
+        _rearrange_transformer_engine_linear(fw_extrace, bw_extrace)
+
     fw_extrace = del_last_used(fw_extrace)
     fw_traces.append(fw_extrace)
```
88 changes: 88 additions & 0 deletions thunder/executors/transformer_engineex.py
@@ -432,3 +432,91 @@ def _linear_grad(a: TensorProxy, w: TensorProxy, b: TensorProxy) -> TensorProxy:
```python
    execution_transform=_linear_transform,
    grad_transform=_linear_grad,
)


def _rearrange_transformer_engine_linear(fw_extrace, bw_extrace):
    """
    Rearrange the TransformerEngine linear symbols `te_linear_*` in the forward trace
    so that we match the constraint that the first FP8 module called in the forward pass
    is the last FP8 module whose gradient is computed in the backward pass.

    Implementation:
        From the backward trace, we find the `ctx_name` of the last `te_functional_linear_backward`.
        Then we iterate over the forward trace and find the `te_linear` which produces the `ctx_name`
        found above. We move this `te_linear` above the first `te_linear` currently in the forward trace.

    .. note::
        We could also have found the `ctx_name` of the first `te_linear` in forward
        and re-ordered the backward pass instead.
        However, on the real-model llama2.c example, I noticed that a FusionExecutor can create a
        pseudo dependency. See the example below.

    Details:
        TransformerEngine takes care of syncing FP8 meta-data
        in the distributed setting (if world_size > 1). The way this is handled
        is by marking the first FP8 module in the forward pass. In the backward pass
        of that module (the last FP8 module in backward), it collects all the FP8 state;
        this state is concatenated, then synced across the processes and then split back
        into individual state again.
        Implementation of the above is in `prepare_forward` and `_prepare_backward` in
        `transformer_engine/pytorch/module/base.py`.
        This means that in thunder, we can't reorder the first `te_linear` or the last backward.
        However, FusionExecutors may reorder them.
        This function takes care of rearranging the trace so that it adheres to this requirement.

        Implementation of `prepare_forward`: https://github.com/NVIDIA/TransformerEngine/blob/2d0ab27f/transformer_engine/pytorch/module/base.py#L501
        Implementation of `_prepare_backward`: https://github.com/NVIDIA/TransformerEngine/blob/2d0ab27f/transformer_engine/pytorch/module/base.py#L67

    Example:
        Forward Trace Snippet:
            [t22, t26] = nvFusion0(t16, t25)
            (t77, ctx_te_2) = te_linear_2(t26, layers_0_attention_wv_weight, None)
            (t53, ctx_te_1) = te_linear_1(t26, layers_0_attention_wk_weight, None)
            (t29, ctx_te_0) = te_linear_0(t26, layers_0_attention_wq_weight, None)

        Backward Trace Snippet (without the `del`s for brevity):
            NOTE: t6822 is part of nvFusion35, which also produces input for the
            te_functional_linear_backward calls below it.
            (t6821, t6822, _) = te_functional_linear_backward(t6819, (i443, i444, i445), (i446, i447), None, ctx_te_2)

            NOTE: `nvFusion35` just does `true_divide(t6822, 2)` and returns it for synchronization,
            but it also picks up a few operations which process the input for the other
            `te_functional_linear_backward` calls below.
            [t6823, t6857, t6900] = nvFusion35(f468, f476, i293, i294, i295, i296, i297, i432, i433, i434, i435, i436, t36, t38, t6810, t6812, t6822)
            t6901 = torch.reshape(t6900, (i186, i187, i188, i189))  # t6901: "cuda:0 f32[128, 256, 6, 48]"
            t6902 = torch.reshape(t6901, (i178, i179, i180))  # t6902: "cuda:0 f32[128, 256, 288]"
            t6858 = torch.reshape(t6857, (i325, i326, i327, i328))  # t6858: "cuda:0 f32[128, 256, 6, 48]"
            t6859 = torch.reshape(t6858, (i317, i318, i319))  # t6859: "cuda:0 f32[128, 256, 288]"
            (t6904, t6905, _) = te_functional_linear_backward(t6902, (i165, i166, i167), (i168, i169), None, ctx_te_0)
            (t6861, t6862, _) = te_functional_linear_backward(t6859, (i304, i305, i306), (i307, i308), None, ctx_te_1)
    """
    # Get the ctx name for the last `te_functional_linear_backward`.
    bwd_bsym_ctx = None
    for _, bsym in enumerate(reversed(bw_extrace.bound_symbols)):
        if bsym.sym.id == te_functional_linear_backward.id:
            bwd_bsym_ctx = bsym.args[-1].name
            break

    first_sym_idx = None
    detected_first_sym_idx = None
    # Find the first `te_linear` in the forward trace
    # and the position of the `te_linear` which produces the last `ctx_name`
    # used in backward.
    for idx, bsym in enumerate(fw_extrace.bound_symbols):
        # Forward symbols are generated on the fly, so we don't
        # have access to them here.
        # Instead, we check the executor field.
        if bsym.sym.executor == transformer_engine_ex:
            # Sanity check.
            assert "te_linear" in bsym.sym.name
            if first_sym_idx is None:
                first_sym_idx = idx
            if bsym.output[-1].name == bwd_bsym_ctx:
                detected_first_sym_idx = idx
                break

    # If the first `te_linear` is not the one that should be first,
    # we move that one to the front.
    if detected_first_sym_idx != first_sym_idx:
        # Move the symbol to be the first `te_linear`.
        fwd_bsyms = fw_extrace.bound_symbols
        sym_to_swap = fwd_bsyms[detected_first_sym_idx]
        del fwd_bsyms[detected_first_sym_idx]
        fwd_bsyms.insert(first_sym_idx, sym_to_swap)
```
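
For reference (not part of this diff), here is a minimal sketch of how the resulting invariant can be checked on a compiled module. It mirrors the llama sanity test added below and assumes only the trace APIs used there (`thunder.last_traces`, `thunder.last_backward_traces`); the helper name is hypothetical.

```python
# Sketch: verify that the ctx produced by the first `te_linear` in the forward
# execution trace is the ctx consumed by the last `te_functional_linear_backward`
# in the backward execution trace (compared by name, as in the sanity test).
import thunder


def te_ordering_holds(jit_model) -> bool:
    fwd = thunder.last_traces(jit_model)[-1]
    bwd = thunder.last_backward_traces(jit_model)[-1]
    first_ctx = next(
        bsym.output[1].name for bsym in fwd.bound_symbols if bsym.sym.name.startswith("te_linear")
    )
    last_ctx = next(
        bsym.args[-1].name for bsym in reversed(bwd.bound_symbols) if bsym.sym.name.startswith("te_functional")
    )
    return first_ctx == last_ctx
```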
254 changes: 254 additions & 0 deletions thunder/tests/distributed/test_ddp.py
@@ -27,6 +27,19 @@
```python
from thunder.tests.framework import TorchExecutor, nvFuserExecutor
from thunder.tests.framework import instantiate

from thunder.executors.transformer_engineex import transformer_engine_ex, TE_AVAILABLE

is_fp8_supported: bool = False
# This will be correctly updated below when TransformerEngine is installed
# and the current environment doesn't support FP8.
fp8_support_reason: str = ""
if TE_AVAILABLE:
    from transformer_engine.pytorch import fp8_autocast
    from transformer_engine.pytorch import Linear as TELinear
    from transformer_engine.pytorch.fp8 import check_fp8_support

    is_fp8_supported, fp8_support_reason = check_fp8_support()

try:
    import expecttest  # noqa: F401
    import hypothesis  # noqa: F401
```
@@ -1241,6 +1254,206 @@ def finalize_pg(pg):
```python
    return None


def _test_ddp_transformer_engine(input_data):
    # Test Description: We run a dummy training loop for a simple `Linear(Relu(Linear(x)))`
    # model with thunder (using the TE executor) and with PyTorch eager + TE,
    # and verify that the weights have converged to the same values and that the
    # FP8 meta state is the same after `n_iter` iterations.
    init_method, world_size, rank, executor, device, dtype, _unused_bucketing_strategy = input_data
    devicetype = devices.device_from_string(device).devicetype
    _unused_dtype = ltorch.to_torch_dtype(dtype)
    init_per_process_distributed(init_method, devicetype, world_size, rank)

    torch.cuda.set_device(rank)

    dim = 256
    n_iter = 10

    class ThunderModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.fc1 = torch.nn.Linear(dim, dim, bias=False)
            self.fc2 = torch.nn.Linear(dim, dim, bias=False)

        def forward(self, x):
            return self.fc2(torch.nn.functional.relu(self.fc1(x)))

    # Weights
    fc1_weight = torch.randn(dim, dim, requires_grad=True).cuda()
    fc2_weight = torch.randn(dim, dim, requires_grad=True).cuda()

    # Inputs (different input on each rank).
    if rank == 0:
        x = torch.arange(dim * dim, dtype=torch.float).view(dim, dim).cuda()
    if rank == 1:
        x = torch.randn(dim, dim).cuda() * 100

    thunder_model = ThunderModel().cuda()
    thunder_model.fc1.weight.data = fc1_weight.clone()
    thunder_model.fc2.weight.data = fc2_weight.clone()

    jit_model = thunder.jit(
        thunder.distributed.ddp(thunder_model),
        executors=[
            transformer_engine_ex,
        ]
        + executor.executors_list(),
    )

    optim = torch.optim.SGD(thunder_model.parameters())

    for _ in range(n_iter):
        with fp8_autocast():
            o = jit_model(x).sum()
        o.backward()
        optim.step()
        optim.zero_grad()

    class TEModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.fc1 = TELinear(dim, dim, bias=False)
            self.fc2 = TELinear(dim, dim, bias=False)

        def forward(self, x):
            return self.fc2(torch.nn.functional.relu(self.fc1(x)))

    te_model = TEModel().cuda()
    te_model.fc1.weight.data = fc1_weight.clone()
    te_model.fc2.weight.data = fc2_weight.clone()

    ddp_model = DDP(te_model)

    optim = torch.optim.SGD(te_model.parameters())

    for _ in range(n_iter):
        with fp8_autocast():
            o = ddp_model(x).sum()

        o.backward()
        optim.step()
        optim.zero_grad()

    thunder_to_te_layer_map = {"te_linear_0": te_model.fc1, "te_linear_1": te_model.fc2}

    fwd_traces = thunder.last_traces(jit_model)

    def is_same_across_ranks(t):
        t_clone = t.clone()
        torch.distributed.all_reduce(t_clone, op=torch.distributed.ReduceOp.AVG)
        assert_close(t, t_clone)

    # Compare the state of the two models.
    comparison_exceptions = []
    for bound_symbol in fwd_traces[-1].bound_symbols:
        if "te_linear" in bound_symbol.sym.name:
            thunder_fp8_meta = bound_symbol._call_ctx[bound_symbol.sym.name].func.fp8_meta
            te_fp8_meta = thunder_to_te_layer_map[bound_symbol.sym.name].fp8_meta
            try:
                # fwd tensor history
                assert_close(thunder_fp8_meta["scaling_fwd"].scale, te_fp8_meta["scaling_fwd"].scale)
                assert_close(thunder_fp8_meta["scaling_fwd"].scale_inv, te_fp8_meta["scaling_fwd"].scale_inv)
                assert_close(thunder_fp8_meta["scaling_fwd"].amax_history, te_fp8_meta["scaling_fwd"].amax_history)
                # bwd tensor history
                assert_close(thunder_fp8_meta["scaling_bwd"].scale, te_fp8_meta["scaling_bwd"].scale)
                assert_close(thunder_fp8_meta["scaling_bwd"].scale_inv, te_fp8_meta["scaling_bwd"].scale_inv)
                assert_close(thunder_fp8_meta["scaling_bwd"].amax_history, te_fp8_meta["scaling_bwd"].amax_history)

                # This has to run on all ranks so that the computation is not blocked
                is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].scale)
                is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].scale_inv)
                is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].amax_history)
                is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].scale)
                is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].scale_inv)
                is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].amax_history)
            except Exception as e:
                # Return exceptions only for rank == 0
                if rank == 0:
                    comparison_exceptions.append(e)

    # Compare weights after `n_iter` iterations
    try:
        assert_close(thunder_model.fc1.weight, te_model.fc1.weight)
        assert_close(thunder_model.fc2.weight, te_model.fc2.weight)
    except Exception as e:
        # Return exceptions only for rank == 0
        if rank == 0:
            comparison_exceptions.append(e)

    return comparison_exceptions


def _test_ddp_transformer_engine_llama_sanity(input_data):
    # Test Description: We run a dummy training loop for a Transformer model.
    # We run a few iterations to check that TransformerEngine doesn't throw an internal assertion
    # due to reordering of forward and backward operators.
    # (This test will fail without `_rearrange_transformer_engine_linear` being called in `torch_autograd.py`.)
    # For more details, see the docstring of `_rearrange_transformer_engine_linear` in transformer_engineex.py.
    from thunder.tests.llama2_model import Transformer, ModelArgs

    init_method, world_size, rank, executor, device, dtype, _unused_bucketing_strategy = input_data
    devicetype = devices.device_from_string(device).devicetype
    _unused_dtype = ltorch.to_torch_dtype(dtype)
    init_per_process_distributed(init_method, devicetype, world_size, rank)

    torch.cuda.set_device(rank)
    # data
    batch_size = 2
    max_seq_len = 32
    vocab_size = 32

    model_args = dict(
        dim=32,
        n_layers=1,
        n_heads=2,
        n_kv_heads=2,
        vocab_size=vocab_size,
        multiple_of=32,
        max_seq_len=max_seq_len,
        dropout=0.0,
    )
    gptconf = ModelArgs(**model_args)
    model = Transformer(gptconf)
    model.to(device)
    x = torch.randint(0, vocab_size, (batch_size, max_seq_len), dtype=torch.int64, device=device)
    y = torch.randint(0, vocab_size, (batch_size, max_seq_len), dtype=torch.int64, device=device)
    jit_model = thunder.jit(
        thunder.distributed.ddp(model), executors=(transformer_engine_ex,) + thunder.get_default_executors()
    )

    sanity_exceptions = []
    try:
        for _ in range(5):
            with fp8_autocast():
                out = jit_model(x, y).sum()
            out.backward()

        fwd_exec_trace = thunder.last_traces(jit_model)[-1]
        bwd_exec_trace = thunder.last_backward_traces(jit_model)[-1]

        # Verify that the first te_linear in fwd_exec_trace is the
        # last one in bwd_exec_trace.
        # We verify that by matching the `ctx` (CollectionProxy) output by `te_linear`, which is
        # passed to backward.
        # As CollectionProxy doesn't implement `__eq__`, we compare them by name.
        first_ctx_name = None
        for bsym in fwd_exec_trace.bound_symbols:
            if bsym.sym.name.startswith("te_linear"):
                first_ctx_name = bsym.output[1].name
                break

        for bsym in reversed(bwd_exec_trace.bound_symbols):
            if bsym.sym.name.startswith("te_functional"):
                assert first_ctx_name == bsym.args[-1].name, (first_ctx_name, bsym.args[-1].name)
                break
    except Exception as e:
        sanity_exceptions.append(e)

    if rank == 0:
        return sanity_exceptions
    return None


# NOTE This is just a stub, see the NOTE for ddp_wrapper
@instantiate(
    dtypes=(thunder.float32,),
```
@@ -1274,5 +1487,46 @@ def test_native_fsdp(executor, devices, dtype, bucket_size_in_mb):
```python
    pass


@instantiate(
    dtypes=(thunder.float32,),
    num_devices=2,
    devicetypes=(devices.DeviceType.CUDA,),
    executors=(TorchExecutor,),
    decorators=(
        pytest.mark.skipif(not TE_AVAILABLE, reason="TransformerEngine is not installed."),
        pytest.mark.skipif(not is_fp8_supported, reason=fp8_support_reason),
        # NOTE: Setting `NVTE_TORCH_COMPILE`
        # It is important to set this flag so that TE doesn't use
        # `torch.compile` to fuse a few operations. This is because
        # `torch.compile` creates a new process and that leads to
        # the error: daemonic processes are not allowed to have children
        # when running the tests.
        # With the setting below, we use `torch.jit` for this test suite.
        # See: https://github.com/NVIDIA/TransformerEngine/blob/a38b291b0d1b04847e8ab1df8550df642a03a27d/transformer_engine/pytorch/jit.py#L11-L19
        unittest.mock.patch.dict(os.environ, {"NVTE_TORCH_COMPILE": "0"}, clear=True),
    ),
)
@ddp_wrapper("test_ddp_transformer_engine", _test_ddp_transformer_engine)
def test_ddp_transformer_engine(executor, devices, dtype):
    pass


@instantiate(
    dtypes=(thunder.float32,),
    num_devices=2,
    devicetypes=(devices.DeviceType.CUDA,),
    executors=(TorchExecutor,),
    decorators=(
        pytest.mark.skipif(not TE_AVAILABLE, reason="TransformerEngine is not installed."),
        pytest.mark.skipif(not is_fp8_supported, reason=fp8_support_reason),
        # See NOTE: Setting `NVTE_TORCH_COMPILE`
        unittest.mock.patch.dict(os.environ, {"NVTE_TORCH_COMPILE": "0"}, clear=True),
    ),
)
@ddp_wrapper("test_ddp_transformer_engine_llama_sanity", _test_ddp_transformer_engine_llama_sanity)
def test_ddp_transformer_engine_llama_sanity(executor, devices, dtype):
    pass


if __name__ == "__main__":
    common_utils.run_tests()
```