Skip to content

Commit

Permalink
Fix a ddp test for #191. (#730)
Browse files Browse the repository at this point in the history
  • Loading branch information
wujingyue authored Jul 9, 2024
1 parent d903cb4 commit 6dfe7e9
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions thunder/tests/distributed/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,7 @@ def test_fsdp_with_padding(
):

from thunder.core.prims import PrimIDs
from thunder.core.transforms import unwrap_one_level_of_subsymbols
from thunder.executors.torchex import pad_prim_impl
from thunder.executors.torchex import slice_prim_impl

Expand All @@ -771,10 +772,13 @@ def forward(self, x):
y.mean().backward()

fw_extrace = thunder.last_traces(jitted)[-1]
# When bookend is turned off, `slice` and `pad` may appear in nvFusion subsymbols.
fw_extrace = unwrap_one_level_of_subsymbols(fw_extrace)
fw_symids = [bsym.sym.id for bsym in fw_extrace.bound_symbols]
self.assertTrue(any(sym_id in {PrimIDs.SLICE, slice_prim_impl.id} for sym_id in fw_symids))

bw_trace = thunder.last_backward_traces(jitted)[0]
bw_trace = unwrap_one_level_of_subsymbols(bw_trace)
bw_symids = [bsym.sym.id for bsym in bw_trace.bound_symbols]
self.assertTrue(any(sym_id in {PrimIDs.PAD, pad_prim_impl.id} for sym_id in bw_symids))

Expand Down

0 comments on commit 6dfe7e9

Please sign in to comment.