Skip to content

Commit

Permalink
TE executor: DDP support
Browse files Browse the repository at this point in the history
  • Loading branch information
kshitij12345 committed Mar 26, 2024
1 parent bdf5c3f commit 1513724
Show file tree
Hide file tree
Showing 3 changed files with 327 additions and 0 deletions.
7 changes: 7 additions & 0 deletions thunder/executors/torch_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,13 @@ def make_trace(func):
if getattr(compile_data.fn, "use_ddp", False):
bw_extrace = sort_waits(bw_extrace)

# Importing here to avoid cyclical dependencies in future.
from thunder.executors.transformer_engineex import transformer_engine_ex, _rearrange_transformer_engine_linear

if transformer_engine_ex in compile_data.executors_list:
# NOTE: `_rearrange_transformer_engine_linear` mutates `fw_extrace`.
_rearrange_transformer_engine_linear(fw_extrace, bw_extrace)

fw_extrace = del_last_used(fw_extrace)
fw_traces.append(fw_extrace)

Expand Down
88 changes: 88 additions & 0 deletions thunder/executors/transformer_engineex.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,3 +430,91 @@ def _linear_grad(a: TensorProxy, w: TensorProxy, b: TensorProxy) -> TensorProxy:
execution_transform=_linear_transform,
grad_transform=_linear_grad,
)


def _rearrange_transformer_engine_linear(fw_extrace, bw_extrace):
"""
Rearrange the TransformerEngine linear symbols `te_linear_*` in forward trace
so that we match the constraint that first FP8 module being called
in forward is the last FP8 module whose gradient is computed in backward pass.
Implementation:
From the backward trace, we find the `ctx_name` of the last `te_functional_linear_backward`.
Then we iterate the forward trace and find the `te_linear` which produces the `ctx_name`
found above. We move this `te_linear` above the first `te_linear` currently in the fwd_trace.
..note::
We could have also done it such that we find the `ctx_name` for first `te_linear` in forward
and re-order the backward pass.
However, on a real model llama2.c example, I noticed that FusionExecutor can create pseudo dependency.
See the example below.
Details:
TransformerEngine takes care of syncing FP8 meta-data
in distributed setting (if world_size > 1). The way this is handled
is by marking the first FP8 module in forward pass. In the backward pass
of that module (last in FP8 module in backward), it collects all the FP8 state,
this state is concatenated, then synced acorss the processes and then split back
into individual state again.
Implementation of the above is in `prepare_forward` and `_prepare_backward` in
`transformer_engine/pytorch/module/base.py`
This means that in thunder, we can't reorder the first `te_linear` or the last backward.
However, FusionExecutors may reorder them.
This function takes care of rearranging such that adhere to this requirement.
Implementation of `prepare_forward`: https://github.com/NVIDIA/TransformerEngine/blob/2d0ab27f/transformer_engine/pytorch/module/base.py#L501
Implementation of `_prepare_backward : https://github.com/NVIDIA/TransformerEngine/blob/2d0ab27f/transformer_engine/pytorch/module/base.py#L67
Example:
Forward Trace Snippet:
[t22, t26] = nvFusion0(t16, t25)
(t77, ctx_te_2) = te_linear_2(t26, layers_0_attention_wv_weight, None)
(t53, ctx_te_1) = te_linear_1(t26, layers_0_attention_wk_weight, None)
(t29, ctx_te_0) = te_linear_0(t26, layers_0_attention_wq_weight, None)
Backward Trace Snippet (without the `del` for brevity):
NOTE: t6822 is part of nvFusion35 which also produces input for te_functional_linear_backward below it.
(t6821, t6822, _) = te_functional_linear_backward(t6819, (i443, i444, i445), (i446, i447), None, ctx_te_2)
NOTE: `nvFusion35` just does `true_divide(t6822, 2)` and returns it for synchronization.
but it also picks up a few operations which process the input for other `te_functional_linear_backward` below.
[t6823, t6857, t6900] = nvFusion35(f468, f476, i293, i294, i295, i296, i297, i432, i433, i434, i435, i436, t36, t38, t6810, t6812, t6822)
t6901 = torch.reshape(t6900, (i186, i187, i188, i189)) # t6901: "cuda:0 f32[128, 256, 6, 48]"
t6902 = torch.reshape(t6901, (i178, i179, i180)) # t6902: "cuda:0 f32[128, 256, 288]"
t6858 = torch.reshape(t6857, (i325, i326, i327, i328)) # t6858: "cuda:0 f32[128, 256, 6, 48]"
t6859 = torch.reshape(t6858, (i317, i318, i319)) # t6859: "cuda:0 f32[128, 256, 288]"
(t6904, t6905, _) = te_functional_linear_backward(t6902, (i165, i166, i167), (i168, i169), None, ctx_te_0)
(t6861, t6862, _) = te_functional_linear_backward(t6859, (i304, i305, i306), (i307, i308), None, ctx_te_1)
"""
# Get the ctx name for the last `te_functional_linear_backward`.
bwd_bsym_ctx = None
for _, bsym in enumerate(reversed(bw_extrace.bound_symbols)):
if bsym.sym.id == te_functional_linear_backward.id:
bwd_bsym_ctx = bsym.args[-1].name
break

first_sym_idx = None
detected_first_sym_idx = None
# Find the first `te_linear` in forward trace
# and the position of `te_linear` which has the last `ctx_name`
# in backward.
for idx, bsym in enumerate(fw_extrace.bound_symbols):
# Forward symbols are generated on the fly so we don't
# have access here.
# Instead we check for the executor field.
if bsym.sym.executor == transformer_engine_ex:
# Sanity check.
assert "te_linear" in bsym.sym.name
if first_sym_idx is None:
first_sym_idx = idx
if bsym.output[-1].name == bwd_bsym_ctx:
detected_first_sym_idx = idx
break

# If the first `te_linear` is not same as that one that should be
# we move it to be the first one.
if detected_first_sym_idx != first_sym_idx:
# Move the symbol to be the first `te_linear`.
fwd_bsyms = fw_extrace.bound_symbols
sym_to_swap = fwd_bsyms[detected_first_sym_idx]
del fwd_bsyms[detected_first_sym_idx]
fwd_bsyms.insert(first_sym_idx, sym_to_swap)
232 changes: 232 additions & 0 deletions thunder/tests/distributed/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,27 @@
from thunder.tests.framework import TorchExecutor, nvFuserExecutor
from thunder.tests.framework import instantiate

# It is important to set this flag so that TE doesn't use
# `torch.compile` to fuse a few operations. This is because
# `torch.compile` creates a new process and that leads to
# the error : daemonic processes are not allowed to have children
# when running the tests.
# With the setting below, we use `torch.jit` for this test suite
# See: https://github.com/NVIDIA/TransformerEngine/blob/a38b291b0d1b04847e8ab1df8550df642a03a27d/transformer_engine/pytorch/jit.py#L11-L19
os.environ["NVTE_TORCH_COMPILE"] = "0"
from thunder.executors.transformer_engineex import transformer_engine_ex, TE_AVAILABLE

is_fp8_supported: bool = False
# This will be correctly updated below when TE Engine is installed
# and if the current environment doesn't support FP8.
fp8_support_reason: str = ""
if TE_AVAILABLE:
from transformer_engine.pytorch import fp8_autocast
from transformer_engine.pytorch import Linear as TELinear
from transformer_engine.pytorch.fp8 import check_fp8_support

is_fp8_supported, fp8_support_reason = check_fp8_support()

try:
import expecttest # noqa: F401
import hypothesis # noqa: F401
Expand Down Expand Up @@ -1213,6 +1234,187 @@ def finalize_pg(pg):
return None


def _test_ddp_transformer_engine(input_data):
# Test Description: We run a dummy training loop for a simple `Linear(Relu(Linear(x)))`
# model with thunder (using TE executor) and with PyTorch eager + TE
# and verify that the weights have converged to same value and
# fp8 meta state is same after `n_iter`.
init_method, world_size, rank, executor, device, dtype, _unused_bucketing_strategy = input_data
devicetype = devices.device_from_string(device).devicetype
_unused_dtype = ltorch.to_torch_dtype(dtype)
init_per_process_distributed(init_method, devicetype, world_size, rank)

torch.cuda.set_device(rank)

dim = 256
n_iter = 10

class ThunderModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.fc1 = torch.nn.Linear(dim, dim, bias=False)
self.fc2 = torch.nn.Linear(dim, dim, bias=False)

def forward(self, x):
return self.fc2(torch.nn.functional.relu(self.fc1(x)))

# Weights
fc1_weight = torch.randn(dim, dim, requires_grad=True).cuda()
fc2_weight = torch.randn(dim, dim, requires_grad=True).cuda()

# Inputs (different input on different rank).
if rank == 0:
x = torch.arange(dim * dim, dtype=torch.float).view(dim, dim).cuda()
if rank == 1:
x = torch.randn(dim, dim).cuda() * 100

thunder_model = ThunderModel().cuda()
thunder_model.fc1.weight.data = fc1_weight.clone()
thunder_model.fc2.weight.data = fc2_weight.clone()

jit_model = thunder.jit(
thunder.distributed.ddp(thunder_model),
executors=[
transformer_engine_ex,
]
+ executor.executors_list(),
)

optim = torch.optim.SGD(thunder_model.parameters())

for _ in range(n_iter):
with fp8_autocast():
o = jit_model(x).sum()
o.backward()
optim.step()
optim.zero_grad()

class TEModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.fc1 = TELinear(dim, dim, bias=False)
self.fc2 = TELinear(dim, dim, bias=False)

def forward(self, x):
return self.fc2(torch.nn.functional.relu(self.fc1(x)))

te_model = TEModel().cuda()
te_model.fc1.weight.data = fc1_weight.clone()
te_model.fc2.weight.data = fc2_weight.clone()

ddp_model = DDP(te_model)

optim = torch.optim.SGD(te_model.parameters())

for _ in range(n_iter):
with fp8_autocast():
o = ddp_model(x).sum()

o.backward()
optim.step()
optim.zero_grad()

thunder_to_te_layer_map = {"te_linear_0": te_model.fc1, "te_linear_1": te_model.fc2}

fwd_traces = thunder.last_traces(jit_model)

def is_same_across_ranks(t):
t_clone = t.clone()
torch.distributed.all_reduce(t_clone, op=torch.distributed.ReduceOp.AVG)
assert_close(t, t_clone)

# Compare the state of the two models.
comparison_exceptions = []
for bound_symbol in fwd_traces[-1].bound_symbols:
if "te_linear" in bound_symbol.sym.name:
thunder_fp8_meta = bound_symbol._call_ctx[bound_symbol.sym.name].func.fp8_meta
te_fp8_meta = thunder_to_te_layer_map[bound_symbol.sym.name].fp8_meta
try:
# fwd tensor history
assert_close(thunder_fp8_meta["scaling_fwd"].scale, te_fp8_meta["scaling_fwd"].scale)
assert_close(thunder_fp8_meta["scaling_fwd"].scale_inv, te_fp8_meta["scaling_fwd"].scale_inv)
assert_close(thunder_fp8_meta["scaling_fwd"].amax_history, te_fp8_meta["scaling_fwd"].amax_history)
# bwd tensor history
assert_close(thunder_fp8_meta["scaling_bwd"].scale, te_fp8_meta["scaling_bwd"].scale)
assert_close(thunder_fp8_meta["scaling_bwd"].scale_inv, te_fp8_meta["scaling_bwd"].scale_inv)
assert_close(thunder_fp8_meta["scaling_bwd"].amax_history, te_fp8_meta["scaling_bwd"].amax_history)

# This has to be on all ranks so that the computation is not blocked
is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].scale)
is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].scale_inv)
is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].amax_history)
is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].scale)
is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].scale_inv)
is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].amax_history)
except Exception as e:
# Return exceptions only for rank==0
if rank == 0:
comparison_exceptions.append(e)

# Compare weights after `n_iters`
try:
assert_close(thunder_model.fc1.weight, te_model.fc1.weight)
assert_close(thunder_model.fc2.weight, te_model.fc2.weight)
except Exception as e:
# Return exceptions only for rank==0
if rank == 0:
comparison_exceptions.append(e)

return comparison_exceptions


def _test_ddp_transformer_engine_llama_sanity(input_data):
# Test Description: We run a dummy training loop for a Transformer Model
# We run a few iterations to see that TransformerEngine doesn't throw internal assertion
# due to reordering of forward and backward operators.
# (This test will fail without `_rearrange_transformer_engine_linear` in `torch_autograd.py`)
# For more details, see docstring for `_rearrange_transformer_engine_linear` in transformer_engine_ex.py.
from thunder.tests.llama2_model import Transformer, ModelArgs

init_method, world_size, rank, executor, device, dtype, _unused_bucketing_strategy = input_data
devicetype = devices.device_from_string(device).devicetype
_unused_dtype = ltorch.to_torch_dtype(dtype)
init_per_process_distributed(init_method, devicetype, world_size, rank)

torch.cuda.set_device(rank)
# data
batch_size = 2
max_seq_len = 32
vocab_size = 32

model_args = dict(
dim=32,
n_layers=1,
n_heads=2,
n_kv_heads=2,
vocab_size=vocab_size,
multiple_of=32,
max_seq_len=max_seq_len,
dropout=0.0,
)
gptconf = ModelArgs(**model_args)
model = Transformer(gptconf)
model.to(device)
x = torch.randint(0, vocab_size, (batch_size, max_seq_len), dtype=torch.int64, device=device)
y = torch.randint(0, vocab_size, (batch_size, max_seq_len), dtype=torch.int64, device=device)
jit_model = thunder.jit(
thunder.distributed.ddp(model), executors=(transformer_engine_ex,) + thunder.get_default_executors()
)

sanity_exceptions = []
try:
for _ in range(5):
with fp8_autocast():
out = jit_model(x, y).sum()
out.backward()
except Exception as e:
sanity_exceptions.append(e)

if rank == 0:
return sanity_exceptions
return None


# NOTE This is just a stub, see the NOTE for ddp_wrapper
@instantiate(
dtypes=(thunder.float32,),
Expand Down Expand Up @@ -1246,5 +1448,35 @@ def test_native_fsdp(executor, devices, dtype, bucket_size_in_mb):
pass


@instantiate(
dtypes=(thunder.float32,),
num_devices=2,
devicetypes=(devices.DeviceType.CUDA,),
executors=(TorchExecutor,),
decorators=(
pytest.mark.skipif(not TE_AVAILABLE, reason="TransformerEngine is not installed."),
pytest.mark.skipif(not is_fp8_supported, reason=fp8_support_reason),
),
)
@ddp_wrapper("test_ddp_transformer_engine", _test_ddp_transformer_engine)
def test_ddp_transformer_engine(executor, devices, dtype):
pass


@instantiate(
dtypes=(thunder.float32,),
num_devices=2,
devicetypes=(devices.DeviceType.CUDA,),
executors=(TorchExecutor,),
decorators=(
pytest.mark.skipif(not TE_AVAILABLE, reason="TransformerEngine is not installed."),
pytest.mark.skipif(not is_fp8_supported, reason=fp8_support_reason),
),
)
@ddp_wrapper("test_ddp_transformer_engine_llama_sanity", _test_ddp_transformer_engine_llama_sanity)
def test_ddp_transformer_engine_llama_sanity(executor, devices, dtype):
pass


if __name__ == "__main__":
common_utils.run_tests()

0 comments on commit 1513724

Please sign in to comment.