diff --git a/thunder/tests/distributed/test_ddp.py b/thunder/tests/distributed/test_ddp.py
index b523cf68dd..04589961f4 100644
--- a/thunder/tests/distributed/test_ddp.py
+++ b/thunder/tests/distributed/test_ddp.py
@@ -34,6 +34,7 @@
     transformer_engine_ex,
     TE_AVAILABLE,
     te_sync_fp8_meta_bwd,
+    TE_VERSION_1_8_PLUS,
 )
 
 
@@ -1596,8 +1597,7 @@ def _test_fsdp_transformer_engine(input_data):
     # and verify that the weights have converged to same value and
     # fp8 meta state is same after `n_iter`.
     init_method, world_size, rank, executor, device, _unused_dtype, kwargs = input_data
-    thunder_fsdp_strategy = kwargs["thunder_fsdp_strategy"]
-    intermediate_activation_sharding = kwargs["intermediate_activation_sharding"]
+    thunder_fsdp_strategy, intermediate_activation_sharding = kwargs["thunder_fsdp_strategy_and_intermediate_sharding"]
     devicetype = devices.device_from_string(device).devicetype
 
     # Setting LOCAL_RANK is necessary for thunder.distributed.fsdp
@@ -1667,7 +1667,7 @@ def forward(self, x):
     te_model.fc2.weight.data = fc2_weight.clone()
     fsdp_model = FullyShardedDataParallel(te_model, auto_wrap_policy=always_wrap_policy)
 
-    if thunder_fsdp_strategy == FSDPType.ZERO3 and intermediate_activation_sharding:
+    if intermediate_activation_sharding:
         transformer_engine.pytorch.distributed.prepare_te_modules_for_fsdp(fsdp_model)
 
     optim = torch.optim.SGD(te_model.parameters())
@@ -1819,18 +1819,14 @@ def test_ddp_transformer_engine_llama_sanity(executor, devices, dtype):
     decorators=(
         # NOTE: ddp_wrapper
         pytest.mark.parametrize(
-            "thunder_fsdp_strategy",
+            "thunder_fsdp_strategy_and_intermediate_sharding",
             (
-                FSDPType.ZERO2,
-                FSDPType.ZERO3,
+                (FSDPType.ZERO2, False),
+                (FSDPType.ZERO3, False),
+                # Intermediate sharding is only available from TE v1.8 onwards
+                *(((FSDPType.ZERO3, True),) if TE_VERSION_1_8_PLUS else ()),
             ),
         ),
-        # NOTE: `intermediate_activation_sharding` only works with `FSDP.Zero3`,
-        # and should be skipped with FSDP.Zero2.
-        pytest.mark.parametrize(
-            "intermediate_activation_sharding",
-            (False, True),
-        ),
         pytest.mark.skipif(not TE_AVAILABLE, reason="TransformerEngine is not installed."),
         pytest.mark.skipif(not is_fp8_supported, reason=fp8_support_reason),
         # See NOTE: Setting `NVTE_TORCH_COMPILE`
@@ -1840,7 +1836,7 @@ def test_ddp_transformer_engine_llama_sanity(executor, devices, dtype):
     ),
 )
 @ddp_wrapper("test_fsdp_transformer_engine", _test_fsdp_transformer_engine)
-def test_fsdp_transformer_engine(executor, devices, dtype, thunder_fsdp_strategy, intermediate_activation_sharding):
+def test_fsdp_transformer_engine(executor, devices, dtype, thunder_fsdp_strategy_and_intermediate_sharding):
     pass
 
 
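For reference, below is a minimal, self-contained sketch (not part of the patch) of the pattern the diff adopts: a single tuple-valued `pytest.mark.parametrize` entry replaces the two independent parametrizations, so the unsupported ZeRO-2 + intermediate-sharding combination is never generated and the ZeRO-3 + intermediate-sharding case appears only when the TransformerEngine version allows it. The `FSDPType` enum, the hard-coded `TE_VERSION_1_8_PLUS` value, and the test name here are stand-ins for illustration, not the real definitions from thunder.

    from enum import Enum, auto

    import pytest


    class FSDPType(Enum):  # stand-in for the real thunder FSDPType enum
        ZERO2 = auto()
        ZERO3 = auto()


    TE_VERSION_1_8_PLUS = False  # assumption: pretend an older TransformerEngine is installed


    @pytest.mark.parametrize(
        "thunder_fsdp_strategy_and_intermediate_sharding",
        (
            (FSDPType.ZERO2, False),
            (FSDPType.ZERO3, False),
            # The extra case is generated only when the TE version supports intermediate sharding.
            *(((FSDPType.ZERO3, True),) if TE_VERSION_1_8_PLUS else ()),
        ),
    )
    def test_combined_parametrization(thunder_fsdp_strategy_and_intermediate_sharding):
        # Unpack the tuple inside the test, mirroring the kwargs unpacking in the patch above.
        strategy, intermediate_activation_sharding = thunder_fsdp_strategy_and_intermediate_sharding
        if intermediate_activation_sharding:
            assert strategy is FSDPType.ZERO3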