213 changes: 195 additions & 18 deletions thunder/tests/distributed/test_ddp.py
@@ -16,6 +16,8 @@
import torch.utils.data as tudata
from torch.distributed import distributed_c10d as c10d
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.fsdp.wrap import always_wrap_policy
from torch.testing import assert_close, make_tensor

import thunder
@@ -1040,12 +1042,12 @@ def __call__(self, test_stub):
init_method = f"{FILE_SCHEMA}{file_name}"

@wraps(test_stub)
def test_fn(executor, devices, dtype, bucket_size_in_mb=0):
def test_fn(executor, devices, dtype, **kwargs):
world_size = len(devices)
input_data = []

for rank in range(world_size):
process_data = (init_method, world_size, rank, executor, devices[rank], dtype, bucket_size_in_mb)
process_data = (init_method, world_size, rank, executor, devices[rank], dtype, kwargs)
input_data.append(process_data)

ctx = mp.get_context("spawn")
@@ -1080,7 +1082,8 @@ def error_callback(ex):
# NOTE This assumes that one process will have rank=0 -- could generalize that to root
# TODO Test training; this test currently tests only the forward pass
def _test_native_ddp_helper(input_data):
init_method, world_size, rank, executor, device, dtype, bucket_size_in_mb = input_data
init_method, world_size, rank, executor, device, dtype, kwargs = input_data
bucket_size_in_mb = kwargs.get("bucket_size_in_mb", 0)

num_samples = 2
tensor_shape = (2, 2)
@@ -1172,7 +1175,8 @@ def _test_native_ddp_helper(input_data):


def _test_native_fsdp_helper(input_data):
init_method, world_size, rank, executor, device, dtype, bucketing_strategy = input_data
init_method, world_size, rank, executor, device, dtype, kwargs = input_data
bucketing_strategy = kwargs["fsdp_bucketing_strategy"]

num_samples = 2
tensor_shape = (2, 2)
@@ -1260,7 +1264,7 @@ def _test_ddp_transformer_engine(input_data):
# model with thunder (using TE executor) and with PyTorch eager + TE
# and verify that the weights have converged to the same value and
# that the fp8 meta state is the same after `n_iter`.
init_method, world_size, rank, executor, device, dtype, _unused_bucketing_strategy = input_data
init_method, world_size, rank, executor, device, dtype, _unused_kwargs = input_data
devicetype = devices.device_from_string(device).devicetype
_unused_dtype = ltorch.to_torch_dtype(dtype)
init_per_process_distributed(init_method, devicetype, world_size, rank)
@@ -1362,7 +1366,18 @@ def is_same_across_ranks(t):
# This has to be on all ranks so that the computation is not blocked
is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].scale)
is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].scale_inv)
is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].amax_history)
# NOTE: TE forward tensor meta-data sync
# Syncing of FP8 meta-data happens in two steps in the forward pass.
# 1. When we enter fp8_autocast(), all the forward fp8 meta-data
# in the global buffer is synced.
# See: https://github.com/NVIDIA/TransformerEngine/blob/6a9edc38bf9b941b7d369af5103fa8fe0b121d61/transformer_engine/pytorch/fp8.py#L409-L412
# 2. After this, `prepare_forward` in the module's forward pass reads
# the synced meta-data from the global buffer.
# See: https://github.com/NVIDIA/TransformerEngine/blob/6a9edc38bf9b941b7d369af5103fa8fe0b121d61/transformer_engine/pytorch/module/base.py#L539-L545
# However, by the end of this forward pass we have seen new inputs and outputs, and their amax values
# are recorded in row 0 of `amax_history` (which will be synced only in the next forward pass).
# So, here we check that every row except row 0 is the same.
is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].amax_history[1:])
is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].scale)
is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].scale_inv)
is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].amax_history)
@@ -1371,16 +1386,16 @@ def is_same_across_ranks(t):
if rank == 0:
comparison_exceptions.append(e)

# Compare weights after `n_iters`
try:
assert_close(thunder_model.fc1.weight, te_model.fc1.weight)
assert_close(thunder_model.fc2.weight, te_model.fc2.weight)
except Exception as e:
# Return exceptions only for rank==0
if rank == 0:
comparison_exceptions.append(e)
# Compare weights after `n_iters`
try:
assert_close(thunder_model.fc1.weight, te_model.fc1.weight)
assert_close(thunder_model.fc2.weight, te_model.fc2.weight)
except Exception as e:
# Return exceptions only for rank==0
if rank == 0:
comparison_exceptions.append(e)

return comparison_exceptions
return comparison_exceptions


def _test_ddp_transformer_engine_llama_sanity(input_data):
@@ -1391,7 +1406,7 @@ def _test_ddp_transformer_engine_llama_sanity(input_data):
# For more details, see docstring for `_rearrange_transformer_engine_linear` in transformer_engine_ex.py.
from thunder.tests.llama2_model import Transformer, ModelArgs

init_method, world_size, rank, executor, device, dtype, _unused_bucketing_strategy = input_data
init_method, world_size, rank, executor, device, dtype, _unused_kwargs = input_data
devicetype = devices.device_from_string(device).devicetype
_unused_dtype = ltorch.to_torch_dtype(dtype)
init_per_process_distributed(init_method, devicetype, world_size, rank)
@@ -1453,6 +1468,141 @@ def _test_ddp_transformer_engine_llama_sanity(input_data):
return None


def _test_fsdp_transformer_engine(input_data):
# Test Description: We run a dummy training loop for a simple `Linear(Relu(Linear(x)))`
# model with thunder (using TE executor) and with PyTorch eager + TE
# and verify that the weights have converged to the same value and
# that the fp8 meta state is the same after `n_iter`.
init_method, world_size, rank, executor, device, _unused_dtype, kwargs = input_data
thunder_fsdp_strategy = kwargs["thunder_fsdp_strategy"]
devicetype = devices.device_from_string(device).devicetype

# Setting LOCAL_RANK is necessary for thunder.distributed.fsdp
with unittest.mock.patch.dict(os.environ, {"LOCAL_RANK": str(rank)}):
init_per_process_distributed(init_method, devicetype, world_size, rank)
torch.cuda.set_device(rank)

dim = 256
n_iter = 10

class ThunderModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.fc1 = torch.nn.Linear(dim, dim, bias=False)
self.fc2 = torch.nn.Linear(dim, dim, bias=False)

def forward(self, x):
return self.fc2(torch.nn.functional.relu(self.fc1(x)))

# Weights
fc1_weight = torch.randn(dim, dim, requires_grad=True, device="cuda")
fc2_weight = torch.randn(dim, dim, requires_grad=True, device="cuda")

# Inputs (a different input on each rank).
if rank == 0:
x = torch.arange(dim * dim, dtype=torch.float, device="cuda").view(dim, dim)
if rank == 1:
x = torch.randn(dim, dim, device="cuda") * 100
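# NOTE: This test is instantiated with num_devices=2 (see the decorator below),
# so only ranks 0 and 1 need inputs here.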

with torch.device("cuda"):
thunder_model = ThunderModel()
thunder_model.fc1.weight.data = fc1_weight.clone()
thunder_model.fc2.weight.data = fc2_weight.clone()

jit_model = thunder.jit(
thunder.distributed.fsdp(thunder_model, sharding_strategy=thunder_fsdp_strategy),
executors=[
transformer_engine_ex,
]
+ executor.executors_list(),
)

optim = torch.optim.SGD(thunder_model.parameters())

for _ in range(n_iter):
o = jit_model(x).sum()
o.backward()
optim.step()
optim.zero_grad()

class TEModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.fc1 = TELinear(dim, dim, bias=False)
self.fc2 = TELinear(dim, dim, bias=False)

def forward(self, x):
return self.fc2(torch.nn.functional.relu(self.fc1(x)))

with torch.device("cuda"):
te_model = TEModel()
te_model.fc1.weight.data = fc1_weight.clone()
te_model.fc2.weight.data = fc2_weight.clone()

fsdp_model = FullyShardedDataParallel(te_model, auto_wrap_policy=always_wrap_policy)
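# `always_wrap_policy` puts every submodule (fc1 and fc2) in its own FSDP unit,
# so each TELinear weight is flattened and sharded across the ranks separately.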

optim = torch.optim.SGD(te_model.parameters())

for _ in range(n_iter):
with fp8_autocast():
o = fsdp_model(x).sum()

o.backward()
optim.step()
optim.zero_grad()

thunder_to_te_layer_map = {"te_linear_0": te_model.fc1, "te_linear_1": te_model.fc2}
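# The TE executor appears to name the replaced linears in call order (fc1 -> te_linear_0,
# fc2 -> te_linear_1); this map pairs each thunder-side TE linear with its eager TE counterpart.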

fwd_traces = thunder.last_traces(jit_model)

def is_same_across_ranks(t):
t_clone = t.clone()
torch.distributed.all_reduce(t_clone, op=torch.distributed.ReduceOp.AVG)
assert_close(t, t_clone)

# Compare the state of the two models.
comparison_exceptions = []
for bound_symbol in fwd_traces[-1].bound_symbols:
if "te_linear" in bound_symbol.sym.name:
thunder_fp8_meta = bound_symbol._call_ctx[bound_symbol.sym.name].func.fp8_meta
te_fp8_meta = thunder_to_te_layer_map[bound_symbol.sym.name].fp8_meta
try:
# fwd tensor history
assert_close(thunder_fp8_meta["scaling_fwd"].scale, te_fp8_meta["scaling_fwd"].scale)
assert_close(thunder_fp8_meta["scaling_fwd"].scale_inv, te_fp8_meta["scaling_fwd"].scale_inv)
assert_close(thunder_fp8_meta["scaling_fwd"].amax_history, te_fp8_meta["scaling_fwd"].amax_history)
# bwd tensor history
assert_close(thunder_fp8_meta["scaling_bwd"].scale, te_fp8_meta["scaling_bwd"].scale)
assert_close(thunder_fp8_meta["scaling_bwd"].scale_inv, te_fp8_meta["scaling_bwd"].scale_inv)
assert_close(thunder_fp8_meta["scaling_bwd"].amax_history, te_fp8_meta["scaling_bwd"].amax_history)

# This has to be on all ranks so that the computation is not blocked
is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].scale)
is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].scale_inv)
# See NOTE: TE forward tensor meta-data sync
is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].amax_history[1:])
is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].scale)
is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].scale_inv)
is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].amax_history)
except Exception as e:
# Return exceptions only for rank==0
if rank == 0:
comparison_exceptions.append(e)

# Compare weights after `n_iters`
shard_size = int(dim / world_size)
fsdp_te_params = tuple(te_model.parameters())
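# NOTE on the comparison below (an assumption, not verified here): thunder's fsdp shards each
# (dim, dim) weight row-wise, so this rank holds `shard_size` rows, while torch FSDP exposes a
# flat 1D shard per wrapped module; with no padding, viewing that flat shard as (shard_size, dim)
# should line up with the same rows on this rank.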
try:
assert_close(thunder_model.fc1.weight, fsdp_te_params[0].view(shard_size, dim))
assert_close(thunder_model.fc2.weight, fsdp_te_params[1].view(shard_size, dim))
except Exception as e:
# Return exceptions only for rank==0
if rank == 0:
comparison_exceptions.append(e)

return comparison_exceptions


# NOTE This is just a stub, see the NOTE for ddp_wrapper
@instantiate(
dtypes=(thunder.float32,),
@@ -1472,7 +1622,7 @@ def test_native_ddp(executor, devices, dtype, bucket_size_in_mb):
devicetypes=(devices.DeviceType.CUDA,),
decorators=(
pytest.mark.parametrize(
"bucket_size_in_mb",
"fsdp_bucketing_strategy",
(
FSDPBucketingStrategy.NONE,
FSDPBucketingStrategy.LAYER,
@@ -1482,7 +1632,7 @@ def test_native_ddp(executor, devices, dtype, bucket_size_in_mb):
),
)
@ddp_wrapper("test_native_fsdp", _test_native_fsdp_helper)
def test_native_fsdp(executor, devices, dtype, bucket_size_in_mb):
def test_native_fsdp(executor, devices, dtype, fsdp_bucketing_strategy):
pass


@@ -1531,5 +1681,32 @@ def test_ddp_transformer_engine_llama_sanity(executor, devices, dtype):
pass


@instantiate(
dtypes=(thunder.float32,),
num_devices=2,
devicetypes=(devices.DeviceType.CUDA,),
executors=(TorchExecutor,),
decorators=(
# NOTE: ddp_wrapper
pytest.mark.parametrize(
"thunder_fsdp_strategy",
(
FSDPType.ZERO2,
FSDPType.ZERO3,
),
),
pytest.mark.skipif(not TE_AVAILABLE, reason="TransformerEngine is not installed."),
pytest.mark.skipif(not is_fp8_supported, reason=fp8_support_reason),
# See NOTE: Setting `NVTE_TORCH_COMPILE`
# NOTE: We don't pass `clear=True` to `unittest.mock.patch.dict` because that may clear paths
# from the environment, causing the spawned process to pick up incorrect dependencies.
unittest.mock.patch.dict(os.environ, {"NVTE_TORCH_COMPILE": "0"}),
),
)
@ddp_wrapper("test_fsdp_transformer_engine", _test_fsdp_transformer_engine)
def test_fsdp_transformer_engine(executor, devices, dtype, thunder_fsdp_strategy):
pass


if __name__ == "__main__":
common_utils.run_tests()