Skip to content

Commit

Permalink
move import
Browse files Browse the repository at this point in the history
  • Loading branch information
kshitij12345 committed Jul 3, 2024
1 parent 57609f7 commit a12cbf7
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions thunder/tests/distributed/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from transformer_engine.pytorch import fp8_autocast
from transformer_engine.pytorch import Linear as TELinear
from transformer_engine.pytorch.fp8 import check_fp8_support, FP8GlobalStateManager
import transformer_engine

is_fp8_supported, fp8_support_reason = check_fp8_support()

Expand Down Expand Up @@ -1665,8 +1666,6 @@ def forward(self, x):
te_model.fc1.weight.data = fc1_weight.clone()
te_model.fc2.weight.data = fc2_weight.clone()

import transformer_engine

fsdp_model = FullyShardedDataParallel(te_model, auto_wrap_policy=always_wrap_policy)
if thunder_fsdp_strategy == FSDPType.ZERO3 and intermediate_activation_sharding:
transformer_engine.pytorch.distributed.prepare_te_modules_for_fsdp(fsdp_model)
Expand Down

0 comments on commit a12cbf7

Please sign in to comment.