Commit e891579
make sure multiple tensors are passed to pack_for_fsdp
at trace level

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
crcrpar committed Mar 27, 2024
1 parent 84ac1b0 commit e891579
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions thunder/tests/distributed/test_ddp.py
@@ -693,6 +693,24 @@ def test_fsdp_grad_parity_with_without_bucketing(
        self.assertEqual(loss, orig_loss)
        self.assertEqual(tuple(p.grad for p in cm.parameters() if p.grad is not None), gradients)

        # Make sure that at least one of the "pack" calls takes multiple tensors.
        from thunder.executors.torchex import pack_for_fsdp_prim_impl
        from thunder.distributed.prims import PrimIDs as DistPrimIDs

        for ex_trace in (thunder.last_traces(cm)[-1], thunder.last_backward_traces(cm)[-1]):
            pack_bsyms = list(
                filter(
                    lambda bsym: bsym.sym.id in {DistPrimIDs.PACK_FOR_FSDP, pack_for_fsdp_prim_impl.id},
                    ex_trace.bound_symbols,
                )
            )
            has_pack_multiple_tensors = False
            for bsym in pack_bsyms:
                first_arg = bsym.args[0]
                self.assertIsInstance(first_arg, list)
                has_pack_multiple_tensors |= len(first_arg) > 1
            self.assertTrue(has_pack_multiple_tensors, msg=f"{[bsym.args[0] for bsym in pack_bsyms]=}")

    @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 devices")
    def test_fsdp_shard_unshard(self):
        from thunder.distributed import _shard_params, _unshard_params
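For reference, a minimal sketch of how the same trace inspection could be done outside the test harness. It reuses only names that already appear in the diff above; `jitted_model` is a hypothetical placeholder for a thunder-jitted, FSDP-transformed module (with bucketing enabled) that has already run a forward and backward pass, so both traces exist.

import thunder
from thunder.distributed.prims import PrimIDs as DistPrimIDs
from thunder.executors.torchex import pack_for_fsdp_prim_impl


def count_fsdp_pack_args(jitted_model):
    """Return, per trace, how many tensors each pack-for-FSDP call receives.

    `jitted_model` is assumed to be a thunder-jitted, FSDP-transformed module
    that has already executed forward and backward, so both traces are available.
    """
    counts = []
    for trace in (thunder.last_traces(jitted_model)[-1], thunder.last_backward_traces(jitted_model)[-1]):
        pack_bsyms = [
            bsym
            for bsym in trace.bound_symbols
            if bsym.sym.id in {DistPrimIDs.PACK_FOR_FSDP, pack_for_fsdp_prim_impl.id}
        ]
        counts.append([len(bsym.args[0]) for bsym in pack_bsyms])
    return counts

With bucketing enabled, at least one of the reported counts per trace is expected to be greater than 1, which is exactly what the new assertion in the test checks.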
