TransformerEngine : Add test with FSDP (and updates to ddp_wrapper in test_ddp.py) #142

Merged: 10 commits, Apr 25, 2024
211 changes: 193 additions & 18 deletions in thunder/tests/distributed/test_ddp.py
@@ -16,6 +16,8 @@
import torch.utils.data as tudata
from torch.distributed import distributed_c10d as c10d
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.fsdp.wrap import always_wrap_policy
from torch.testing import assert_close, make_tensor

import thunder
@@ -1040,12 +1042,12 @@ def __call__(self, test_stub):
init_method = f"{FILE_SCHEMA}{file_name}"

@wraps(test_stub)
def test_fn(executor, devices, dtype, bucket_size_in_mb=0):
def test_fn(executor, devices, dtype, **kwargs):
world_size = len(devices)
input_data = []

for rank in range(world_size):
process_data = (init_method, world_size, rank, executor, devices[rank], dtype, bucket_size_in_mb)
process_data = (init_method, world_size, rank, executor, devices[rank], dtype, kwargs)
input_data.append(process_data)

ctx = mp.get_context("spawn")
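The hunks above replace test_fn's fixed bucket_size_in_mb parameter with arbitrary **kwargs that get packed into each rank's process_data tuple. A minimal standalone sketch of that plumbing, with hypothetical names (wrapper, fsdp_helper) standing in for ddp_wrapper and the helpers below; the real wrapper spawns one process per rank via a multiprocessing pool instead of looping:

def wrapper(helper, world_size=2, **kwargs):
    # Mirror test_fn: pack the parametrized kwargs into one tuple per rank.
    input_data = [(rank, kwargs) for rank in range(world_size)]
    # The real ddp_wrapper hands each tuple to a spawned process; here we just loop.
    return [helper(data) for data in input_data]

def fsdp_helper(input_data):
    # Mirror _test_native_fsdp_helper: unpack the tuple and read only the kwarg it needs.
    rank, kwargs = input_data
    return rank, kwargs["fsdp_bucketing_strategy"]

print(wrapper(fsdp_helper, fsdp_bucketing_strategy="NONE"))
# expected: [(0, 'NONE'), (1, 'NONE')]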
@@ -1080,7 +1082,8 @@ def error_callback(ex):
# NOTE This assumes that one process will have rank=0 -- could generalize that to root
# TODO Test training, this test just currently tests forward
def _test_native_ddp_helper(input_data):
init_method, world_size, rank, executor, device, dtype, bucket_size_in_mb = input_data
init_method, world_size, rank, executor, device, dtype, kwargs = input_data
bucket_size_in_mb = kwargs["bucket_size_in_mb"]

num_samples = 2
tensor_shape = (2, 2)
@@ -1172,7 +1175,8 @@ def _test_native_ddp_helper(input_data):


def _test_native_fsdp_helper(input_data):
init_method, world_size, rank, executor, device, dtype, bucketing_strategy = input_data
init_method, world_size, rank, executor, device, dtype, kwargs = input_data
bucketing_strategy = kwargs["fsdp_bucketing_strategy"]

num_samples = 2
tensor_shape = (2, 2)
@@ -1260,7 +1264,7 @@ def _test_ddp_transformer_engine(input_data):
# model with thunder (using the TE executor) and with PyTorch eager + TE,
# and verify that the weights have converged to the same values and the
# fp8 meta state is the same after `n_iter` iterations.
init_method, world_size, rank, executor, device, dtype, _unused_bucketing_strategy = input_data
init_method, world_size, rank, executor, device, dtype, _unused_kwargs = input_data
devicetype = devices.device_from_string(device).devicetype
_unused_dtype = ltorch.to_torch_dtype(dtype)
init_per_process_distributed(init_method, devicetype, world_size, rank)
@@ -1362,7 +1366,18 @@ def is_same_across_ranks(t):
# This has to be on all ranks so that the computation is not blocked
is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].scale)
is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].scale_inv)
is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].amax_history)
# NOTE: TE forward tensor meta-data sync
# Syncing of FP8 meta-data happens in two steps in the forward pass.
# 1. When we enter fp8_autocast(), all the forward fp8 meta-data
# in the global buffer is synced.
# See: https://github.com/NVIDIA/TransformerEngine/blob/6a9edc38bf9b941b7d369af5103fa8fe0b121d61/transformer_engine/pytorch/fp8.py#L409-L412
# 2. After this, in the module's forward pass (in `prepare_forward`),
# the synced meta-data is read back from the global buffer.
# See: https://github.com/NVIDIA/TransformerEngine/blob/6a9edc38bf9b941b7d369af5103fa8fe0b121d61/transformer_engine/pytorch/module/base.py#L539-L545
# However, by the end of this forward pass we have seen new inputs and outputs. Their amax values are
# recorded in row 0 of `amax_history`, which will be synced only in the next forward pass.
# So, here we check that every row except row 0 is the same across ranks (a standalone sketch follows this hunk).
is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].amax_history[1:])
is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].scale)
is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].scale_inv)
is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].amax_history)
@@ -1371,16 +1386,16 @@ def is_same_across_ranks(t):
if rank == 0:
comparison_exceptions.append(e)

# Compare weights after `n_iters`
try:
assert_close(thunder_model.fc1.weight, te_model.fc1.weight)
assert_close(thunder_model.fc2.weight, te_model.fc2.weight)
except Exception as e:
# Return exceptions only for rank==0
if rank == 0:
comparison_exceptions.append(e)
# Compare weights after `n_iters`
try:
assert_close(thunder_model.fc1.weight, te_model.fc1.weight)
assert_close(thunder_model.fc2.weight, te_model.fc2.weight)
except Exception as e:
# Return exceptions only for rank==0
if rank == 0:
comparison_exceptions.append(e)

return comparison_exceptions
return comparison_exceptions


def _test_ddp_transformer_engine_llama_sanity(input_data):
@@ -1391,7 +1406,7 @@ def _test_ddp_transformer_engine_llama_sanity(input_data):
# For more details, see docstring for `_rearrange_transformer_engine_linear` in transformer_engine_ex.py.
from thunder.tests.llama2_model import Transformer, ModelArgs

init_method, world_size, rank, executor, device, dtype, _unused_bucketing_strategy = input_data
init_method, world_size, rank, executor, device, dtype, _unused_kwargs = input_data
devicetype = devices.device_from_string(device).devicetype
_unused_dtype = ltorch.to_torch_dtype(dtype)
init_per_process_distributed(init_method, devicetype, world_size, rank)
@@ -1453,6 +1468,139 @@ def _test_ddp_transformer_engine_llama_sanity(input_data):
return None


def _test_fsdp_transformer_engine(input_data):
# Test Description: We run a dummy training loop for a simple `Linear(Relu(Linear(x)))`
# model with thunder (using the TE executor) and with PyTorch FSDP + TE,
# and verify that the weights have converged to the same values and the
# fp8 meta state is the same after `n_iter` iterations.
init_method, world_size, rank, executor, device, _unused_dtype, kwargs = input_data
thunder_fsdp_strategy = kwargs["thunder_fsdp_strategy"]
devicetype = devices.device_from_string(device).devicetype

# Setting LOCAL_RANK is necessary for thunder.distributed.fsdp
with unittest.mock.patch.dict(os.environ, {"LOCAL_RANK": str(rank)}, clear=True):
init_per_process_distributed(init_method, devicetype, world_size, rank)
torch.cuda.set_device(rank)

dim = 256
n_iter = 10

class ThunderModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.fc1 = torch.nn.Linear(dim, dim, bias=False)
self.fc2 = torch.nn.Linear(dim, dim, bias=False)

def forward(self, x):
return self.fc2(torch.nn.functional.relu(self.fc1(x)))

# Weights
fc1_weight = torch.randn(dim, dim, requires_grad=True).cuda()
fc2_weight = torch.randn(dim, dim, requires_grad=True).cuda()

# Inputs (different input on different rank).
if rank == 0:
x = torch.arange(dim * dim, dtype=torch.float).view(dim, dim).cuda()
if rank == 1:
x = torch.randn(dim, dim).cuda() * 100

thunder_model = ThunderModel().cuda()
thunder_model.fc1.weight.data = fc1_weight.clone()
thunder_model.fc2.weight.data = fc2_weight.clone()

jit_model = thunder.jit(
thunder.distributed.fsdp(thunder_model, sharding_strategy=thunder_fsdp_strategy),
executors=[
transformer_engine_ex,
]
+ executor.executors_list(),
)

optim = torch.optim.SGD(thunder_model.parameters())

for _ in range(n_iter):
o = jit_model(x).sum()
o.backward()
optim.step()
optim.zero_grad()

class TEModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.fc1 = TELinear(dim, dim, bias=False)
self.fc2 = TELinear(dim, dim, bias=False)

def forward(self, x):
return self.fc2(torch.nn.functional.relu(self.fc1(x)))

te_model = TEModel().cuda()
te_model.fc1.weight.data = fc1_weight.clone()
te_model.fc2.weight.data = fc2_weight.clone()

fsdp_model = FullyShardedDataParallel(te_model, auto_wrap_policy=always_wrap_policy)

optim = torch.optim.SGD(te_model.parameters())

for _ in range(n_iter):
with fp8_autocast():
o = fsdp_model(x).sum()

o.backward()
optim.step()
optim.zero_grad()

thunder_to_te_layer_map = {"te_linear_0": te_model.fc1, "te_linear_1": te_model.fc2}

fwd_traces = thunder.last_traces(jit_model)

def is_same_across_ranks(t):
t_clone = t.clone()
torch.distributed.all_reduce(t_clone, op=torch.distributed.ReduceOp.AVG)
assert_close(t, t_clone)

# Compare the state of the two models.
comparison_exceptions = []
for bound_symbol in fwd_traces[-1].bound_symbols:
if "te_linear" in bound_symbol.sym.name:
thunder_fp8_meta = bound_symbol._call_ctx[bound_symbol.sym.name].func.fp8_meta
te_fp8_meta = thunder_to_te_layer_map[bound_symbol.sym.name].fp8_meta
try:
# fwd tensor history
assert_close(thunder_fp8_meta["scaling_fwd"].scale, te_fp8_meta["scaling_fwd"].scale)
assert_close(thunder_fp8_meta["scaling_fwd"].scale_inv, te_fp8_meta["scaling_fwd"].scale_inv)
assert_close(thunder_fp8_meta["scaling_fwd"].amax_history, te_fp8_meta["scaling_fwd"].amax_history)
# bwd tensor history
assert_close(thunder_fp8_meta["scaling_bwd"].scale, te_fp8_meta["scaling_bwd"].scale)
assert_close(thunder_fp8_meta["scaling_bwd"].scale_inv, te_fp8_meta["scaling_bwd"].scale_inv)
assert_close(thunder_fp8_meta["scaling_bwd"].amax_history, te_fp8_meta["scaling_bwd"].amax_history)

# This has to be on all ranks so that the computation is not blocked
is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].scale)
is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].scale_inv)
# See NOTE: TE forward tensor meta-data sync
is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].amax_history[1:])
is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].scale)
is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].scale_inv)
is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].amax_history)
except Exception as e:
# Return exceptions only for rank==0
if rank == 0:
comparison_exceptions.append(e)

# Compare weights after `n_iter` iterations (a standalone sketch of the sharding layout follows this helper).
shard_size = int(dim / world_size)
fsdp_te_params = tuple(te_model.parameters())
try:
assert_close(thunder_model.fc1.weight, fsdp_te_params[0].view(shard_size, dim))
assert_close(thunder_model.fc2.weight, fsdp_te_params[1].view(shard_size, dim))
except Exception as e:
# Return exceptions only for rank==0
if rank == 0:
comparison_exceptions.append(e)

return comparison_exceptions
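A minimal sketch of the sharding layout that the weight comparison above assumes, using the test's sizes (dim=256, world_size=2); the tensors and slicing are stand-ins, only the shapes mirror the real check:

import torch

dim, world_size = 256, 2
shard_size = dim // world_size  # 128 rows of each (dim, dim) weight per rank

full_weight = torch.randn(dim, dim)
rank = 0
# The comparison expects each rank's thunder weight to be a (shard_size, dim) slice,
# while PyTorch FSDP keeps its shard as a flat 1-D parameter.
local_shard = full_weight[rank * shard_size : (rank + 1) * shard_size]
flat_shard = local_shard.reshape(-1)

# Viewing the flat shard as (shard_size, dim) recovers the 2-D layout that
# assert_close compares against thunder_model.fc1.weight / fc2.weight.
assert torch.equal(flat_shard.view(shard_size, dim), local_shard)
print(local_shard.shape, flat_shard.shape)  # torch.Size([128, 256]) torch.Size([32768])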


# NOTE This is just a stub, see the NOTE for ddp_wrapper
@instantiate(
dtypes=(thunder.float32,),
@@ -1472,7 +1620,7 @@ def test_native_ddp(executor, devices, dtype, bucket_size_in_mb):
devicetypes=(devices.DeviceType.CUDA,),
decorators=(
pytest.mark.parametrize(
"bucket_size_in_mb",
"fsdp_bucketing_strategy",
(
FSDPBucketingStrategy.NONE,
FSDPBucketingStrategy.LAYER,
@@ -1482,7 +1630,7 @@ def test_native_ddp(executor, devices, dtype, bucket_size_in_mb):
),
)
@ddp_wrapper("test_native_fsdp", _test_native_fsdp_helper)
def test_native_fsdp(executor, devices, dtype, bucket_size_in_mb):
def test_native_fsdp(executor, devices, dtype, fsdp_bucketing_strategy):
pass


@@ -1531,5 +1679,32 @@ def test_ddp_transformer_engine_llama_sanity(executor, devices, dtype):
pass


@instantiate(
dtypes=(thunder.float32,),
num_devices=2,
devicetypes=(devices.DeviceType.CUDA,),
executors=(TorchExecutor,),
decorators=(
# NOTE: ddp_wrapper
pytest.mark.parametrize(
"thunder_fsdp_strategy",
(
FSDPType.ZERO2,
FSDPType.ZERO3,
),
),
pytest.mark.skipif(not TE_AVAILABLE, reason="TransformerEngine is not installed."),
pytest.mark.skipif(not is_fp8_supported, reason=fp8_support_reason),
# See NOTE: Setting `NVTE_TORCH_COMPILE`
# NOTE: We don't pass `clear=True` to `unittest.mock.patch.dict`, as that would clear paths
# from the environment and could lead to incorrect dependencies being picked up in the
# spawned process (a standalone sketch of the difference follows this test).
unittest.mock.patch.dict(os.environ, {"NVTE_TORCH_COMPILE": "0"}),
),
)
@ddp_wrapper("test_fsdp_transformer_engine", _test_fsdp_transformer_engine)
def test_fsdp_transformer_engine(executor, devices, dtype, thunder_fsdp_strategy):
pass
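A standalone sketch of the `clear=True` behaviour the NOTE above warns about, using only `os` and `unittest.mock` (no Thunder or TransformerEngine involved):

import os
import unittest.mock

n_vars = len(os.environ)

# Without clear=True, patch.dict only adds/overrides the given keys; everything
# else (PATH, LD_LIBRARY_PATH, ...) survives, so a process spawned inside the
# context still resolves its dependencies normally.
with unittest.mock.patch.dict(os.environ, {"NVTE_TORCH_COMPILE": "0"}):
    assert os.environ["NVTE_TORCH_COMPILE"] == "0"
    assert len(os.environ) >= n_vars

# With clear=True, the environment is wiped first and only the given keys remain
# for the duration of the context, which is why the test above avoids it.
with unittest.mock.patch.dict(os.environ, {"NVTE_TORCH_COMPILE": "0"}, clear=True):
    assert list(os.environ.keys()) == ["NVTE_TORCH_COMPILE"]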


if __name__ == "__main__":
common_utils.run_tests()