
Commit a2353ec
skip intermediate_sharding test for v1.7
kshitij12345 committed Jul 4, 2024
1 parent a12cbf7 commit a2353ec
Showing 1 changed file with 9 additions and 13 deletions.
22 changes: 9 additions & 13 deletions thunder/tests/distributed/test_ddp.py
@@ -34,6 +34,7 @@
     transformer_engine_ex,
     TE_AVAILABLE,
     te_sync_fp8_meta_bwd,
+    TE_VERSION_1_8_PLUS,
 )


@@ -1596,8 +1597,7 @@ def _test_fsdp_transformer_engine(input_data):
     # and verify that the weights have converged to same value and
     # fp8 meta state is same after `n_iter`.
     init_method, world_size, rank, executor, device, _unused_dtype, kwargs = input_data
-    thunder_fsdp_strategy = kwargs["thunder_fsdp_strategy"]
-    intermediate_activation_sharding = kwargs["intermediate_activation_sharding"]
+    thunder_fsdp_strategy, intermediate_activation_sharding = kwargs["thunder_fsdp_strategy_and_intermediate_sharding"]
     devicetype = devices.device_from_string(device).devicetype

     # Setting LOCAL_RANK is necessary for thunder.distributed.fsdp
@@ -1667,7 +1667,7 @@ def forward(self, x):
     te_model.fc2.weight.data = fc2_weight.clone()

     fsdp_model = FullyShardedDataParallel(te_model, auto_wrap_policy=always_wrap_policy)
-    if thunder_fsdp_strategy == FSDPType.ZERO3 and intermediate_activation_sharding:
+    if intermediate_activation_sharding:
         transformer_engine.pytorch.distributed.prepare_te_modules_for_fsdp(fsdp_model)
     optim = torch.optim.SGD(te_model.parameters())
@@ -1819,18 +1819,14 @@ def test_ddp_transformer_engine_llama_sanity(executor, devices, dtype):
     decorators=(
         # NOTE: ddp_wrapper
         pytest.mark.parametrize(
-            "thunder_fsdp_strategy",
+            "thunder_fsdp_strategy_and_intermediate_sharding",
             (
-                FSDPType.ZERO2,
-                FSDPType.ZERO3,
+                (FSDPType.ZERO2, False),
+                (FSDPType.ZERO3, False),
+                # Intermediate sharding is only available from TE v1.8 onwards.
+                *(((FSDPType.ZERO3, True),) if TE_VERSION_1_8_PLUS else ()),
             ),
         ),
-        # NOTE: `intermediate_activation_sharding` only works with `FSDP.Zero3`,
-        # and should be skipped with FSDP.Zero2.
-        pytest.mark.parametrize(
-            "intermediate_activation_sharding",
-            (False, True),
-        ),
         pytest.mark.skipif(not TE_AVAILABLE, reason="TransformerEngine is not installed."),
         pytest.mark.skipif(not is_fp8_supported, reason=fp8_support_reason),
         # See NOTE: Setting `NVTE_TORCH_COMPILE`
@@ -1840,7 +1836,7 @@ def test_ddp_transformer_engine_llama_sanity(executor, devices, dtype):
     ),
 )
 @ddp_wrapper("test_fsdp_transformer_engine", _test_fsdp_transformer_engine)
-def test_fsdp_transformer_engine(executor, devices, dtype, thunder_fsdp_strategy, intermediate_activation_sharding):
+def test_fsdp_transformer_engine(executor, devices, dtype, thunder_fsdp_strategy_and_intermediate_sharding):
     pass


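For readers unfamiliar with the pattern this commit switches to, below is a minimal, self-contained sketch (hypothetical names, only `pytest` assumed; not code from this repository): two formerly independent parametrize axes are folded into one tuple-valued parameter, so the case that depends on a newer feature can be included conditionally behind a version flag.

```python
import pytest

# Stand-in for a version/feature check such as TE_VERSION_1_8_PLUS (hypothetical here).
FEATURE_AVAILABLE = False


@pytest.mark.parametrize(
    "strategy_and_sharding",
    (
        ("zero2", False),
        ("zero3", False),
        # The ("zero3", True) combination is appended only when the feature exists,
        # so older installations never generate that test case at collection time.
        *((("zero3", True),) if FEATURE_AVAILABLE else ()),
    ),
)
def test_example(strategy_and_sharding):
    # Unpack the combined parameter inside the test body.
    strategy, intermediate_sharding = strategy_and_sharding
    assert strategy in ("zero2", "zero3")
```

Folding the two axes into one parameter also drops the invalid (ZERO2, True) combination that the old cross-product parametrization generated and then had to guard against in the test body.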
