From 847085966799d6351dbead31e560054b144edb69 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Sun, 24 Mar 2024 14:08:09 -0700 Subject: [PATCH] reflect old jit signature changes in test_ddp (#67) --- thunder/tests/distributed/test_ddp.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/thunder/tests/distributed/test_ddp.py b/thunder/tests/distributed/test_ddp.py index 3562a10a5c..03bedfd334 100644 --- a/thunder/tests/distributed/test_ddp.py +++ b/thunder/tests/distributed/test_ddp.py @@ -184,7 +184,7 @@ def func( e = c @ b + a return e, d - cfunc = thunder.jit(func, executors_list=_executor.executors_list()) + cfunc = thunder.jit(func, executors=_executor.executors_list()) device = f"cuda:{self.rank}" a = make_tensor((2, 2), device=device, dtype=torch.float32) b = make_tensor((2, 2), device=device, dtype=torch.float32) @@ -258,7 +258,7 @@ def lc_foo( process_group = c10d.new_group() # NOTE Preprocessing is disabled because we call thunder.torch operations directly - cfoo = thunder.jit(lc_foo, executors_list=_executor.executors_list()) + cfoo = thunder.jit(lc_foo, executors=_executor.executors_list()) for op, async_op in product((None, torch.distributed.ReduceOp.SUM), (False, True)): expected = foo(a, b, op, process_group, async_op) @@ -310,7 +310,7 @@ def lc_foo( process_group = c10d.new_group() # NOTE Preprocessing is disabled because we call thunder.torch operations directly - cfoo = thunder.jit(lc_foo, executors_list=_executor.executors_list()) + cfoo = thunder.jit(lc_foo, executors=_executor.executors_list()) for async_op in (True, False): expected = foo(a, b, process_group, async_op) @@ -368,7 +368,7 @@ def lc_foo( process_group = c10d.new_group() # NOTE Preprocessing is disabled because we call thunder.torch operations directly - cfoo = thunder.jit(lc_foo, executors_list=_executor.executors_list()) + cfoo = thunder.jit(lc_foo, executors=_executor.executors_list()) for async_op in (True, False): expected = foo(a, b, process_group, async_op) @@ -425,7 +425,7 @@ def lc_foo( process_group = c10d.new_group() # NOTE Preprocessing is disabled because we call thunder.torch operations directly - cfoo = thunder.jit(lc_foo, executors_list=_executor.executors_list()) + cfoo = thunder.jit(lc_foo, executors=_executor.executors_list()) for op, async_op in product((None, torch.distributed.ReduceOp.SUM), (False, True)): expected = foo(a, b, op, process_group, async_op) @@ -447,7 +447,7 @@ def test_ddp_grad_bucketing(self, executor, bucket_size_in_mb: int): m = ToyModel().to(device) cm = thunder.jit( ddp(m, bucket_size_in_mb=bucket_size_in_mb), - executors_list=executors_map[executor].executors_list(), + executors=executors_map[executor].executors_list(), ) x = torch.ones((2, 12)).to(device) cm(x).mean().backward() @@ -551,7 +551,7 @@ def get_model_and_optimizer(device): compiled_ddp_m = thunder.jit( ddp_m, cache_mode=CACHE_OPTIONS.CONSTANT_VALUES, - executors_list=executors_map[executor].executors_list(), + executors=executors_map[executor].executors_list(), ) optimizer = torch.optim.SGD(compiled_ddp_m.parameters(), lr=1e-3) return compiled_ddp_m, optimizer @@ -642,7 +642,7 @@ def test_ddp_grad_parity_with_without_bucketing(self, executor): m.load_state_dict(initial_model_state) cm = thunder.jit( ddp(m, bucket_size_in_mb=bucket_size_in_mb), - executors_list=executors_map[executor].executors_list(), + executors=executors_map[executor].executors_list(), ) x = torch.ones((2, 12)).to(device) cm(x).mean().backward() @@ -680,7 +680,7 @@ def test_fsdp_grad_parity_with_without_bucketing( m.load_state_dict(initial_model_state) cm = thunder.jit( fsdp(m, device=device, bucketing_strategy=bucketing_strategy, sharding_strategy=fsdptype), - executors_list=executors_map[executor].executors_list(), + executors=executors_map[executor].executors_list(), ) x = torch.ones((2, 12), device=device) loss = cm(x).mean() @@ -822,8 +822,7 @@ def check_inflight_allgather_number(trc, n: int, is_bucket: bool): m = Block(config).to(device=device) cm = thunder.jit( fsdp(m, device=device, broadcast_from=0, bucketing_strategy=bucketing_strategy, sharding_strategy=fsdptype), - interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON, - executors_list=executors_map[executor].executors_list(), + executors=executors_map[executor].executors_list(), ) x = torch.ones((2, config.block_size, config.n_embd), device=device) loss = cm(x).mean() @@ -1065,7 +1064,7 @@ def _test_native_ddp_helper(input_data): ddp_model = ddp(model) cmodel = thunder.jit( ddp_model, - executors_list=executor.executors_list(), + executors=executor.executors_list(), ) comparison_exceptions = [] @@ -1174,7 +1173,7 @@ def finalize_pg(pg): cmodel = thunder.jit( fsdp_model, - executors_list=executor.executors_list(), + executors=executor.executors_list(), ) comparison_exceptions = []