def has_wait_prims(trace: TraceCtx) -> bool:
    """Report whether ``trace`` contains a ``wait_prim_impl`` bound symbol.

    Args:
        trace: Trace whose ``bound_symbols`` are scanned.

    Returns:
        ``True`` if the ``sym.id`` of any bound symbol equals ``wait_prim_impl.id``,
        ``False`` otherwise (including when ``trace`` has no bound symbols).
    """
    # Function-scope import, matching how the other helpers in this module pull in torchex symbols.
    from thunder.executors.torchex import wait_prim_impl

    for bound_symbol in trace.bound_symbols:
        if bound_symbol.sym.id == wait_prim_impl.id:
            return True
    return False
def maybe_sort_waits(trace: TraceCtx) -> tuple[bool, TraceCtx]:
    """Run :func:`~thunder.distributed.utils.sort_waits` on ``trace`` when it is applicable.

    Sorting is applied only if :mod:`torch.distributed` is available and ``trace`` contains
    at least one ``wait_prim_impl`` bound symbol (checked with :func:`has_wait_prims`).

    Args:
        trace: The trace to sort.

    Returns:
        ``(True, sorted_trace)`` when sorting was applied, where ``sorted_trace`` is the result of
        ``sort_waits`` with its provenance set to ``"Sort Waits"``; otherwise ``(False, trace)``
        with the input trace handed back as is.
    """
    from torch.distributed import is_available
    from thunder.core.trace import TraceProvenance

    # Guard clause: nothing to do without torch.distributed or without any wait in the trace.
    if not is_available() or not has_wait_prims(trace):
        return False, trace

    sorted_trace = sort_waits(trace)
    # Tag the result so callers can locate this trace by its provenance string.
    sorted_trace.set_provenance(TraceProvenance("Sort Waits"))
    return True, sorted_trace
@common_utils.parametrize(
    "executor,op",
    product(tuple(executors_map.keys()), ("all_gather_into_tensor", "reduce_scatter_tensor")),
)
def test_native_collective_comms(self, executor, op):
    """Check that an async native collective is jitted correctly and its ``wait`` gets sorted.

    The function under test launches ``torch.distributed.<op>`` with ``async_op=True``, does
    unrelated work, waits on the handle and then mutates ``output`` in place. The test verifies

    * the jitted callable matches eager execution for both the returned tensor and ``output``;
    * a trace with the provenance "Sort Waits" was recorded;
    * in that trace every ``wait_prim_impl`` sits more than two bound symbols after the most
      recent communication op.
    """
    from thunder.executors.torchex import all_gather_prim_impl, reduce_scatter_prim_impl, wait_prim_impl

    device = f"cuda:{self.rank}"
    shape = (4, 2)
    # NOTE(review): these output shapes presumably assume a world size of 2 -- confirm against the test harness.
    output_shape = (8, 2) if op.startswith("all_gather") else (2, 2)

    comm = getattr(torch.distributed, op)
    _executor = executors_map[executor]

    def foo(
        a: torch.Tensor,
        b: torch.Tensor,
        output: torch.Tensor,
        group: torch.distributed.ProcessGroup,
    ):
        c = a + b
        handle = comm(output, c, group=group, async_op=True)
        e = c + 1
        handle.wait()
        f = e * b
        output *= 2
        return f

    group = torch.distributed.distributed_c10d._get_default_group()
    jitted = _executor.make_callable(foo)
    a = make_tensor(shape, device=device, dtype=torch.float32)
    b = make_tensor(shape, device=device, dtype=torch.float32)
    output = torch.empty(output_shape, device=device, dtype=torch.float32)
    # Independent copies for the eager reference run, since ``foo`` mutates ``output`` in place.
    a_, b_, output_ = a.clone().detach(), b.clone().detach(), output.clone().detach()

    f = jitted(a, b, output, group)
    f_ = foo(a_, b_, output_, group)
    torch.testing.assert_close(f, f_)
    torch.testing.assert_close(output, output_)

    traces = thunder.last_traces(jitted)
    trace_with_waits_sorted = next(
        (t for t in traces if t._provenance is not None and t._provenance.pss == "Sort Waits"),
        None,
    )
    # Fail with a clear message instead of an AttributeError on ``None`` below.
    self.assertIsNotNone(
        trace_with_waits_sorted,
        msg='no trace with provenance "Sort Waits" found in thunder.last_traces(jitted)',
    )

    bound_symbols = trace_with_waits_sorted.bound_symbols
    comm_ids = {all_gather_prim_impl.id, reduce_scatter_prim_impl.id}
    # Start past the end so that a wait appearing before any communication op fails the check.
    comm_idx = len(bound_symbols)
    num_waits_checked = 0
    for idx, bsym in enumerate(bound_symbols):
        if bsym.sym.id in comm_ids:
            comm_idx = idx
        if bsym.sym.id == wait_prim_impl.id:
            num_waits_checked += 1
            self.assertGreater(idx, comm_idx + 2)
    # Guard against the ordering check passing vacuously.
    self.assertGreater(num_waits_checked, 0, msg="sorted trace contains no wait_prim_impl")