TransformerEngine : Add test with FSDP (and updates to ddp_wrapper in test_ddp.py) #142

Merged: 10 commits, Apr 25, 2024
178 changes: 170 additions & 8 deletions thunder/tests/distributed/test_ddp.py
@@ -16,6 +16,8 @@
import torch.utils.data as tudata
from torch.distributed import distributed_c10d as c10d
from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.distributed.fsdp import FullyShardedDataParallel
+from torch.distributed.fsdp.wrap import always_wrap_policy
from torch.testing import assert_close, make_tensor

import thunder
@@ -1039,12 +1041,12 @@ def __call__(self, test_stub):
init_method = f"{FILE_SCHEMA}{file_name}"

@wraps(test_stub)
-def test_fn(executor, devices, dtype, bucket_size_in_mb=0):
+def test_fn(executor, devices, dtype, **kwargs):
world_size = len(devices)
input_data = []

for rank in range(world_size):
-process_data = (init_method, world_size, rank, executor, devices[rank], dtype, bucket_size_in_mb)
+process_data = (init_method, world_size, rank, executor, devices[rank], dtype, kwargs)
input_data.append(process_data)

ctx = mp.get_context("spawn")
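
In sketch form, the hunk above generalizes `ddp_wrapper` from the single hard-coded `bucket_size_in_mb` parameter to arbitrary keyword arguments, so any pytest-parametrized kwarg on a decorated test stub now reaches the spawned per-rank helper through `input_data`. A minimal illustration of the pattern (`_my_helper` and `my_flag` are hypothetical names, not part of this PR):

# Hypothetical helper/test pair showing the new kwargs plumbing.
def _my_helper(input_data):
    # The wrapper packs the parametrized kwargs as the last tuple element.
    init_method, world_size, rank, executor, device, dtype, kwargs = input_data
    my_flag = kwargs["my_flag"]  # each helper pulls only the kwargs it declared
    ...

# Parametrized (e.g. via pytest.mark.parametrize("my_flag", ...)) like the
# real tests further down; the stub body stays empty.
@ddp_wrapper("test_my_feature", _my_helper)
def test_my_feature(executor, devices, dtype, my_flag):
    pass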
@@ -1079,7 +1081,8 @@ def error_callback(ex):
# NOTE This assumes that one process will have rank=0 -- could generalize that to root
# TODO Test training, this test just currently tests forward
def _test_native_ddp_helper(input_data):
-init_method, world_size, rank, executor, device, dtype, bucket_size_in_mb = input_data
+init_method, world_size, rank, executor, device, dtype, kwargs = input_data
+bucket_size_in_mb = kwargs["bucket_size_in_mb"]

num_samples = 2
tensor_shape = (2, 2)
@@ -1171,7 +1174,8 @@ def _test_native_ddp_helper(input_data):


def _test_native_fsdp_helper(input_data):
-init_method, world_size, rank, executor, device, dtype, bucketing_strategy = input_data
+init_method, world_size, rank, executor, device, dtype, kwargs = input_data
+bucketing_strategy = kwargs["fsdp_bucketing_strategy"]

num_samples = 2
tensor_shape = (2, 2)
@@ -1259,7 +1263,7 @@ def _test_ddp_transformer_engine(input_data):
# model with thunder (using TE executor) and with PyTorch eager + TE
# and verify that the weights have converged to the same value and the
# fp8 meta state is the same after `n_iter`.
-init_method, world_size, rank, executor, device, dtype, _unused_bucketing_strategy = input_data
+init_method, world_size, rank, executor, device, dtype, _unused_kwargs = input_data
devicetype = devices.device_from_string(device).devicetype
_unused_dtype = ltorch.to_torch_dtype(dtype)
init_per_process_distributed(init_method, devicetype, world_size, rank)
@@ -1391,7 +1395,7 @@ def _test_ddp_transformer_engine_llama_sanity(input_data):
# For more details, see docstring for `_rearrange_transformer_engine_linear` in transformer_engine_ex.py.
from thunder.tests.llama2_model import Transformer, ModelArgs

-init_method, world_size, rank, executor, device, dtype, _unused_bucketing_strategy = input_data
+init_method, world_size, rank, executor, device, dtype, _unused_kwargs = input_data
devicetype = devices.device_from_string(device).devicetype
_unused_dtype = ltorch.to_torch_dtype(dtype)
init_per_process_distributed(init_method, devicetype, world_size, rank)
@@ -1454,6 +1458,139 @@ def _test_ddp_transformer_engine_llama_sanity(input_data):
return None


def _test_fsdp_transformer_engine(input_data):
# Test Description: We run a dummy training loop for a simple `Linear(Relu(Linear(x)))`
# model with thunder (using TE executor) and with PyTorch eager + TE
# and verify that the weights have converged to the same value and the
# fp8 meta state is the same after `n_iter`.
init_method, world_size, rank, executor, device, _unused_dtype, kwargs = input_data
thunder_fsdp_strategy = kwargs["thunder_fsdp_strategy"]
devicetype = devices.device_from_string(device).devicetype

# Setting LOCAL_RANK is necessary for thunder.distributed.fsdp
with unittest.mock.patch.dict(os.environ, {"LOCAL_RANK": str(rank)}, clear=True):
init_per_process_distributed(init_method, devicetype, world_size, rank)
torch.cuda.set_device(rank)

dim = 256
n_iter = 10

class ThunderModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.fc1 = torch.nn.Linear(dim, dim, bias=False)
self.fc2 = torch.nn.Linear(dim, dim, bias=False)

def forward(self, x):
return self.fc2(torch.nn.functional.relu(self.fc1(x)))

# Weights
fc1_weight = torch.randn(dim, dim, requires_grad=True).cuda()
fc2_weight = torch.randn(dim, dim, requires_grad=True).cuda()

# Inputs (different input on different rank).
if rank == 0:
x = torch.arange(dim * dim, dtype=torch.float).view(dim, dim).cuda()
if rank == 1:
x = torch.randn(dim, dim).cuda() * 100

thunder_model = ThunderModel().cuda()
thunder_model.fc1.weight.data = fc1_weight.clone()
thunder_model.fc2.weight.data = fc2_weight.clone()

jit_model = thunder.jit(
thunder.distributed.fsdp(thunder_model, sharding_strategy=thunder_fsdp_strategy),
executors=[
transformer_engine_ex,
]
+ executor.executors_list(),
)
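# `transformer_engine_ex` is listed before the stock executors above so the TE
# executor can claim the linear layers first; those claims surface as the
# `te_linear_*` symbols inspected later in this test.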

optim = torch.optim.SGD(thunder_model.parameters())

for _ in range(n_iter):
with fp8_autocast():
o = jit_model(x).sum()
o.backward()
optim.step()
optim.zero_grad()

class TEModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.fc1 = TELinear(dim, dim, bias=False)
self.fc2 = TELinear(dim, dim, bias=False)

def forward(self, x):
return self.fc2(torch.nn.functional.relu(self.fc1(x)))

te_model = TEModel().cuda()
te_model.fc1.weight.data = fc1_weight.clone()
te_model.fc2.weight.data = fc2_weight.clone()

ddp_model = FullyShardedDataParallel(te_model, auto_wrap_policy=always_wrap_policy)

optim = torch.optim.SGD(te_model.parameters())

for _ in range(n_iter):
with fp8_autocast():
o = ddp_model(x).sum()

o.backward()
optim.step()
optim.zero_grad()

thunder_to_te_layer_map = {"te_linear_0": te_model.fc1, "te_linear_1": te_model.fc2}

fwd_traces = thunder.last_traces(jit_model)

def is_same_across_ranks(t):
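# If `t` is identical on every rank, averaging it across ranks leaves it
# unchanged, so comparing `t` against the all-reduced copy checks cross-rank
# equality without designating a reference rank.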
t_clone = t.clone()
torch.distributed.all_reduce(t_clone, op=torch.distributed.ReduceOp.AVG)
assert_close(t, t_clone)

# Compare the state of the two models.
comparison_exceptions = []
for bound_symbol in fwd_traces[-1].bound_symbols:
if "te_linear" in bound_symbol.sym.name:
thunder_fp8_meta = bound_symbol._call_ctx[bound_symbol.sym.name].func.fp8_meta
te_fp8_meta = thunder_to_te_layer_map[bound_symbol.sym.name].fp8_meta
try:
# fwd tensor history
assert_close(thunder_fp8_meta["scaling_fwd"].scale, te_fp8_meta["scaling_fwd"].scale)
assert_close(thunder_fp8_meta["scaling_fwd"].scale_inv, te_fp8_meta["scaling_fwd"].scale_inv)
assert_close(thunder_fp8_meta["scaling_fwd"].amax_history, te_fp8_meta["scaling_fwd"].amax_history)
# bwd tensor history
assert_close(thunder_fp8_meta["scaling_bwd"].scale, te_fp8_meta["scaling_bwd"].scale)
assert_close(thunder_fp8_meta["scaling_bwd"].scale_inv, te_fp8_meta["scaling_bwd"].scale_inv)
assert_close(thunder_fp8_meta["scaling_bwd"].amax_history, te_fp8_meta["scaling_bwd"].amax_history)

# This has to be on all ranks so that the computation is not blocked
is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].scale)
is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].scale_inv)
is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].amax_history)
is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].scale)
is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].scale_inv)
is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].amax_history)
except Exception as e:
# Return exceptions only for rank==0
if rank == 0:
comparison_exceptions.append(e)

# Compare weights after `n_iters`
shard_size = int(dim / world_size)
fsdp_te_params = tuple(te_model.parameters())
try:
assert_close(thunder_model.fc1.weight, fsdp_te_params[0].view(shard_size, dim))
assert_close(thunder_model.fc2.weight, fsdp_te_params[1].view(shard_size, dim))
except Exception as e:
# Return exceptions only for rank==0
if rank == 0:
comparison_exceptions.append(e)

return comparison_exceptions
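
The weight comparison just above leans on a bit of shape arithmetic: PyTorch FSDP stores a flat 1-D shard per rank, while the thunder-side weights stay 2-D and row-sharded, so the eager-side shard is reshaped before `assert_close`. A hedged sketch of that bookkeeping using this test's shapes (the contiguous-rows layout is an assumption the test relies on, not something this sketch verifies):

# Shard arithmetic behind `fsdp_te_params[0].view(shard_size, dim)`.
dim, world_size = 256, 2
shard_size = dim // world_size      # 128 rows of the (dim, dim) weight per rank
flat_numel = shard_size * dim       # 32768 elements in the rank's flat FSDP shard
# view(shard_size, dim) reinterprets the flat shard as this rank's rows, making
# it directly comparable to the sharded thunder weight of the same shape.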


# NOTE This is just a stub, see the NOTE for ddp_wrapper
@instantiate(
dtypes=(thunder.float32,),
@@ -1473,7 +1610,7 @@ def test_native_ddp(executor, devices, dtype, bucket_size_in_mb):
devicetypes=(devices.DeviceType.CUDA,),
decorators=(
pytest.mark.parametrize(
"bucket_size_in_mb",
"fsdp_bucketing_strategy",
(
FSDPBucketingStrategy.NONE,
FSDPBucketingStrategy.LAYER,
@@ -1483,7 +1620,7 @@ def test_native_ddp(executor, devices, dtype, bucket_size_in_mb):
),
)
@ddp_wrapper("test_native_fsdp", _test_native_fsdp_helper)
-def test_native_fsdp(executor, devices, dtype, bucket_size_in_mb):
+def test_native_fsdp(executor, devices, dtype, fsdp_bucketing_strategy):
pass


@@ -1528,5 +1665,30 @@ def test_ddp_transformer_engine_llama_sanity(executor, devices, dtype):
pass


@instantiate(
dtypes=(thunder.float32,),
num_devices=2,
devicetypes=(devices.DeviceType.CUDA,),
executors=(TorchExecutor,),
decorators=(
# NOTE: ddp_wrapper
pytest.mark.parametrize(
"thunder_fsdp_strategy",
(
FSDPType.ZERO2,
FSDPType.ZERO3,
),
),
pytest.mark.skipif(not TE_AVAILABLE, reason="TransformerEngine is not installed."),
pytest.mark.skipif(not is_fp8_supported, reason=fp8_support_reason),
# NOTE: Setting `NVTE_TORCH_COMPILE=0` (via the patch below) so that
# TransformerEngine does not route through torch.compile in this test.
unittest.mock.patch.dict(os.environ, {"NVTE_TORCH_COMPILE": "0"}, clear=True),
),
)
@ddp_wrapper("test_fsdp_transformer_engine", _test_fsdp_transformer_engine)
def test_fsdp_transformer_engine(executor, devices, dtype, thunder_fsdp_strategy):
pass


if __name__ == "__main__":
common_utils.run_tests()
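
As a usage note, the invocation below is an assumption rather than part of the diff: the new test can be selected with pytest's `-k` filter, and it needs at least two CUDA devices, fp8-capable hardware, and TransformerEngine installed.

# Assumed invocation (not from the PR):
#   pytest thunder/tests/distributed/test_ddp.py -k test_fsdp_transformer_engine
# or, since the module runs `common_utils.run_tests()` under `__main__`:
#   python thunder/tests/distributed/test_ddp.py -k test_fsdp_transformer_engine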