reflect old jit signature changes in test_ddp (#67)
t-vi authored Mar 24, 2024
1 parent f5e9d2b commit 8470859
Showing 1 changed file with 12 additions and 13 deletions.
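
The change is mechanical: thunder.jit renamed its executors_list keyword to executors, and every call site in the test file below is updated to match. As a minimal sketch of the rename (the toy function is an illustrative assumption, not code from this commit; it assumes a working thunder install):

import torch
import thunder

def add(a, b):
    return a + b

# Before this commit the tests spelled the keyword `executors_list`:
#   cadd = thunder.jit(add, executors_list=[...])   # old signature
# After the rename the keyword is `executors`; None selects the defaults:
cadd = thunder.jit(add, executors=None)
print(cadd(torch.randn(2, 2), torch.randn(2, 2)))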
thunder/tests/distributed/test_ddp.py (12 additions, 13 deletions)
@@ -184,7 +184,7 @@ def func(
             e = c @ b + a
             return e, d
 
-        cfunc = thunder.jit(func, executors_list=_executor.executors_list())
+        cfunc = thunder.jit(func, executors=_executor.executors_list())
         device = f"cuda:{self.rank}"
         a = make_tensor((2, 2), device=device, dtype=torch.float32)
         b = make_tensor((2, 2), device=device, dtype=torch.float32)
@@ -258,7 +258,7 @@ def lc_foo(
         process_group = c10d.new_group()
 
         # NOTE Preprocessing is disabled because we call thunder.torch operations directly
-        cfoo = thunder.jit(lc_foo, executors_list=_executor.executors_list())
+        cfoo = thunder.jit(lc_foo, executors=_executor.executors_list())
 
         for op, async_op in product((None, torch.distributed.ReduceOp.SUM), (False, True)):
             expected = foo(a, b, op, process_group, async_op)
@@ -310,7 +310,7 @@ def lc_foo(
         process_group = c10d.new_group()
 
         # NOTE Preprocessing is disabled because we call thunder.torch operations directly
-        cfoo = thunder.jit(lc_foo, executors_list=_executor.executors_list())
+        cfoo = thunder.jit(lc_foo, executors=_executor.executors_list())
 
         for async_op in (True, False):
             expected = foo(a, b, process_group, async_op)
@@ -368,7 +368,7 @@ def lc_foo(
         process_group = c10d.new_group()
 
         # NOTE Preprocessing is disabled because we call thunder.torch operations directly
-        cfoo = thunder.jit(lc_foo, executors_list=_executor.executors_list())
+        cfoo = thunder.jit(lc_foo, executors=_executor.executors_list())
 
         for async_op in (True, False):
             expected = foo(a, b, process_group, async_op)
@@ -425,7 +425,7 @@ def lc_foo(
         process_group = c10d.new_group()
 
         # NOTE Preprocessing is disabled because we call thunder.torch operations directly
-        cfoo = thunder.jit(lc_foo, executors_list=_executor.executors_list())
+        cfoo = thunder.jit(lc_foo, executors=_executor.executors_list())
 
         for op, async_op in product((None, torch.distributed.ReduceOp.SUM), (False, True)):
             expected = foo(a, b, op, process_group, async_op)
@@ -447,7 +447,7 @@ def test_ddp_grad_bucketing(self, executor, bucket_size_in_mb: int):
         m = ToyModel().to(device)
         cm = thunder.jit(
             ddp(m, bucket_size_in_mb=bucket_size_in_mb),
-            executors_list=executors_map[executor].executors_list(),
+            executors=executors_map[executor].executors_list(),
         )
         x = torch.ones((2, 12)).to(device)
         cm(x).mean().backward()
@@ -551,7 +551,7 @@ def get_model_and_optimizer(device):
             compiled_ddp_m = thunder.jit(
                 ddp_m,
                 cache_mode=CACHE_OPTIONS.CONSTANT_VALUES,
-                executors_list=executors_map[executor].executors_list(),
+                executors=executors_map[executor].executors_list(),
             )
             optimizer = torch.optim.SGD(compiled_ddp_m.parameters(), lr=1e-3)
             return compiled_ddp_m, optimizer
@@ -642,7 +642,7 @@ def test_ddp_grad_parity_with_without_bucketing(self, executor):
             m.load_state_dict(initial_model_state)
             cm = thunder.jit(
                 ddp(m, bucket_size_in_mb=bucket_size_in_mb),
-                executors_list=executors_map[executor].executors_list(),
+                executors=executors_map[executor].executors_list(),
             )
             x = torch.ones((2, 12)).to(device)
             cm(x).mean().backward()
@@ -680,7 +680,7 @@ def test_fsdp_grad_parity_with_without_bucketing(
             m.load_state_dict(initial_model_state)
             cm = thunder.jit(
                 fsdp(m, device=device, bucketing_strategy=bucketing_strategy, sharding_strategy=fsdptype),
-                executors_list=executors_map[executor].executors_list(),
+                executors=executors_map[executor].executors_list(),
             )
             x = torch.ones((2, 12), device=device)
             loss = cm(x).mean()
@@ -822,8 +822,7 @@ def check_inflight_allgather_number(trc, n: int, is_bucket: bool):
         m = Block(config).to(device=device)
         cm = thunder.jit(
             fsdp(m, device=device, broadcast_from=0, bucketing_strategy=bucketing_strategy, sharding_strategy=fsdptype),
-            interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON,
-            executors_list=executors_map[executor].executors_list(),
+            executors=executors_map[executor].executors_list(),
         )
         x = torch.ones((2, config.block_size, config.n_embd), device=device)
         loss = cm(x).mean()
@@ -1065,7 +1064,7 @@ def _test_native_ddp_helper(input_data):
     ddp_model = ddp(model)
     cmodel = thunder.jit(
         ddp_model,
-        executors_list=executor.executors_list(),
+        executors=executor.executors_list(),
     )
 
     comparison_exceptions = []
@@ -1174,7 +1173,7 @@ def finalize_pg(pg):
 
     cmodel = thunder.jit(
         fsdp_model,
-        executors_list=executor.executors_list(),
+        executors=executor.executors_list(),
     )
 
     comparison_exceptions = []