From 6dfe7e939a19d1ef5ab259de8709a79f0104fa42 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Tue, 9 Jul 2024 00:26:06 -0700 Subject: [PATCH] Fix a ddp test for #191. (#730) --- thunder/tests/distributed/test_ddp.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/thunder/tests/distributed/test_ddp.py b/thunder/tests/distributed/test_ddp.py index 04589961f4..3fe33b98a1 100644 --- a/thunder/tests/distributed/test_ddp.py +++ b/thunder/tests/distributed/test_ddp.py @@ -750,6 +750,7 @@ def test_fsdp_with_padding( ): from thunder.core.prims import PrimIDs + from thunder.core.transforms import unwrap_one_level_of_subsymbols from thunder.executors.torchex import pad_prim_impl from thunder.executors.torchex import slice_prim_impl @@ -771,10 +772,13 @@ def forward(self, x): y.mean().backward() fw_extrace = thunder.last_traces(jitted)[-1] + # When bookend is turned off, `slice` and `pad` may appear in nvFusion subsymbols. + fw_extrace = unwrap_one_level_of_subsymbols(fw_extrace) fw_symids = [bsym.sym.id for bsym in fw_extrace.bound_symbols] self.assertTrue(any(sym_id in {PrimIDs.SLICE, slice_prim_impl.id} for sym_id in fw_symids)) bw_trace = thunder.last_backward_traces(jitted)[0] + bw_trace = unwrap_one_level_of_subsymbols(bw_trace) bw_symids = [bsym.sym.id for bsym in bw_trace.bound_symbols] self.assertTrue(any(sym_id in {PrimIDs.PAD, pad_prim_impl.id} for sym_id in bw_symids))