apply sort_waits if dist_prims.wait is in a trace (#776)

crcrpar authored Jul 19, 2024
1 parent b1cf1bf commit 721e28e
Showing 4 changed files with 106 additions and 4 deletions.
13 changes: 13 additions & 0 deletions thunder/__init__.py
@@ -597,11 +597,24 @@ def get_computation_and_inputs(*args, **kwargs):
        computation_traces.append(computation_trc)
        cs.last_computation_transformation_stop = time.perf_counter_ns()

        from thunder.executors.passes import transform_for_execution as transform_for_execution_pass
        from thunder.executors.passes import _transform_for_operator_executor_execution
        from thunder.distributed.utils import maybe_sort_waits

        with langctxs.langctx(cd.langctx):
            tmp_comp_trc = _transform_for_operator_executor_execution(computation_trc, cd.executors_list)
            is_transformed, tmp_comp_trc = maybe_sort_waits(tmp_comp_trc)
            if is_transformed:
                computation_trc = tmp_comp_trc
                computation_traces.append(computation_trc)

        with langctxs.langctx(cd.langctx):
            extraces = transform_for_execution(
                computation_trc,
                executors_list=cd.executors_list,
                use_del_last_used=True,
            )
        computation_traces.append(computation_trc)
        computation_trc = extraces[-1]

        if not compile_options.get("disable_inplace_copy_check", False):
32 changes: 30 additions & 2 deletions thunder/distributed/utils.py
@@ -51,6 +51,12 @@ def key(node: Node) -> int:
    return primal_trace


def has_wait_prims(trace: TraceCtx) -> bool:
    from thunder.executors.torchex import wait_prim_impl

    return any(bsym.sym.id == wait_prim_impl.id for bsym in trace.bound_symbols)


def sort_communication_ops(execution_trace):
    """
    Sorts the wait_prim_impl nodes in the execution trace to be as far from the
@@ -71,7 +77,7 @@ def sort_communication_ops(execution_trace):
        unpack_for_fsdp_prim_impl,
    )

    if not any(bsym.sym.id == wait_prim_impl.id for bsym in execution_trace.bound_symbols):
    if not has_wait_prims(execution_trace):
        return execution_trace

    order_in_trace = {bsym: i for i, bsym in enumerate(execution_trace.bound_symbols)}
@@ -122,14 +128,15 @@ def sort_waits(execution_trace):
    Returns:
        TraceCtx: The sorted execution trace.
    """
    from thunder.core import prims
    from thunder.executors.torchex import (
        wait_prim_impl,
        reduce_scatter_prim_impl,
        all_reduce_prim_impl,
        all_gather_prim_impl,
    )

    if not any(bsym.sym.id == wait_prim_impl.id for bsym in execution_trace.bound_symbols):
    if not has_wait_prims(execution_trace):
        return execution_trace

    order_in_trace = {bsym: i for i, bsym in enumerate(execution_trace.bound_symbols)}
@@ -144,6 +151,10 @@ def key(node: Node) -> int:
            case reduce_scatter_prim_impl.id | all_reduce_prim_impl.id | all_gather_prim_impl.id:
                # Prefer larger communication ops over smaller ones
                return -node.bsym.args[0].numel
            # note(crcrpar): When a dist collective comm is applied on a func arg and the arg is not included in return,
            # this sort could put `wait` after `return` stmt.
            case prims.PrimIDs.RETURN | prims.python_return.id:
                return len(order_in_trace) + 1
            case _:
                # Prefer nodes that are earlier in the trace
                return order_in_trace[node.bsym]
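For context: the situation the note(crcrpar) comment guards against is a collective that writes into a function argument which never feeds the return value, so no data dependency keeps its wait before the trace's return. A minimal, hypothetical sketch (not part of the diff; it mirrors the `foo` function in the new test below):

# Illustrative user code only, not repository code.
# Assumes torch.distributed is initialized and `group` is a valid process group.
import torch

def user_fn(a, b, output, group):
    c = a + b
    # The async collective writes into `output`, which is not part of the return value.
    handle = torch.distributed.all_gather_into_tensor(output, c, group=group, async_op=True)
    e = c + 1
    handle.wait()  # without the RETURN case above, sorting could push this wait past `return`
    return e * b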
@@ -165,6 +176,23 @@ def key(node: Node) -> int:
    return new_execution_trace


def maybe_sort_waits(trace: TraceCtx) -> tuple[bool, TraceCtx]:
    """Apply ``sort_waits`` to ``trace`` if possible.

    :func:`~thunder.distributed.utils.sort_waits` is applied only when :mod:`torch.distributed`
    is available and ``trace`` contains at least one :func:`thunder.distributed.prims.wait`.
    """
    from torch.distributed import is_available
    from thunder.core.trace import TraceProvenance

    if is_available() and has_wait_prims(trace):
        trace_with_waits_sorted = sort_waits(trace)
        trace_with_waits_sorted.set_provenance(TraceProvenance("Sort Waits"))
        return True, trace_with_waits_sorted
    else:
        return False, trace


def limit_in_flight_allgathers(
    execution_trace: TraceCtx,
    max_in_flight_comms: int,
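A hedged usage sketch of the new helper (not part of the diff), mirroring the call site added to thunder/__init__.py above; `trc` and `computation_traces` are hypothetical caller-side names:

# Sketch only: `trc` is assumed to be a TraceCtx that has already gone through the
# operator-executor pass, and `computation_traces` is the caller's list of intermediate traces.
from thunder.distributed.utils import maybe_sort_waits

is_transformed, maybe_sorted = maybe_sort_waits(trc)
if is_transformed:
    # sort_waits ran: the returned trace carries the "Sort Waits" provenance.
    computation_traces.append(maybe_sorted)
    trc = maybe_sorted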
11 changes: 9 additions & 2 deletions thunder/executors/torch_autograd.py
@@ -217,7 +217,8 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
    # For performance we need the wait_prim_impl nodes in the execution trace to be as far from the
    # communication ops as possible. But that causes the all_gather_prim_impl nodes to be gathered at
    # the start of the backward trace and increases the peak allocated memory
    if getattr(compile_data.fn, "use_fsdp", False):
    use_fsdp: bool = getattr(compile_data.fn, "use_fsdp", False)
    if use_fsdp:
        assert hasattr(compile_data.fn, "sharding_strategy")
        if getattr(compile_data.fn, "sharding_strategy") == FSDPType.ZERO3:
            from thunder.distributed import FSDPBucketingStrategy
@@ -249,8 +250,14 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
                compile_data.fn.bucketing_strategy != FSDPBucketingStrategy.NONE,
            )
        bw_extrace = sort_waits(bw_extrace)
    if getattr(compile_data.fn, "use_ddp", False):
    use_ddp: bool = getattr(compile_data.fn, "use_ddp", False)
    if use_ddp:
        bw_extrace = sort_waits(bw_extrace)
    if (not use_ddp) and (not use_fsdp):
        from thunder.distributed.utils import maybe_sort_waits

        _, fw_extrace = maybe_sort_waits(fw_extrace)
        _, bw_extrace = maybe_sort_waits(bw_extrace)

    # Importing here to avoid cyclical dependencies in future.
    from thunder.executors.transformer_engineex import _transformer_engine_bwd_fp8_meta_sync, transformer_engine_ex
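To make the performance comment at the top of this file's first hunk concrete, an illustrative, comment-only sketch (not from the diff) of the tradeoff it describes for an FSDP backward trace:

# Illustrative ordering only, not an actual thunder trace.
# Pushing every wait far from its collective maximizes compute/communication overlap:
#     all_gather(p0); all_gather(p1); ...; all_gather(pN)   # every unsharded param alive at once
#     wait(p0); grad0 = ...; wait(p1); grad1 = ...
# but the front-loaded all_gathers raise peak allocated memory, which is presumably why the
# ZeRO-3 path also caps outstanding collectives (see limit_in_flight_allgathers and its
# max_in_flight_comms parameter in thunder/distributed/utils.py above) before sorting waits.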
54 changes: 54 additions & 0 deletions thunder/tests/distributed/test_ops.py
@@ -272,6 +272,60 @@ def lc_foo(
        actual = cfoo(a, b, op, process_group, async_op, dim=dim)
        self.assertEqual(actual, expected)

    @common_utils.parametrize(
        "executor,op",
        product(tuple(executors_map.keys()), ("all_gather_into_tensor", "reduce_scatter_tensor")),
    )
    def test_native_collective_comms(self, executor, op):
        from thunder.executors.torchex import all_gather_prim_impl, reduce_scatter_prim_impl, wait_prim_impl

        device = f"cuda:{self.rank}"
        shape = (4, 2)
        output_shape = (8, 2) if op.startswith("all_gather") else (2, 2)

        comm = getattr(torch.distributed, op)
        _executor = executors_map[executor]

        def foo(
            a: torch.Tensor,
            b: torch.Tensor,
            output: torch.Tensor,
            group: torch.distributed.ProcessGroup,
        ):
            c = a + b
            handle = comm(output, c, group=group, async_op=True)
            e = c + 1
            handle.wait()
            f = e * b
            output *= 2
            return f

        group = torch.distributed.distributed_c10d._get_default_group()
        jitted = _executor.make_callable(foo)
        a = make_tensor(shape, device=device, dtype=torch.float32)
        b = make_tensor(shape, device=device, dtype=torch.float32)
        output = torch.empty(output_shape, device=device, dtype=torch.float32)
        a_, b_, output_ = a.clone().detach(), b.clone().detach(), output.clone().detach()

        f = jitted(a, b, output, group)
        f_ = foo(a_, b_, output_, group)
        torch.testing.assert_close(f, f_)
        torch.testing.assert_close(output, output_)

        traces = thunder.last_traces(jitted)
        trace_with_waits_sorted = None
        for t in traces:
            if t._provenance is not None and t._provenance.pss == "Sort Waits":
                trace_with_waits_sorted = t
                break

        comm_idx = len(t.bound_symbols)
        for idx, bsym in enumerate(trace_with_waits_sorted.bound_symbols):
            if bsym.sym.id in {all_gather_prim_impl.id, reduce_scatter_prim_impl.id}:
                comm_idx = idx
            if bsym.sym.id == wait_prim_impl.id:
                self.assertGreater(idx, comm_idx + 2)
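For readers of the assertions above, a rough picture (illustrative only, not an actual trace) of the ordering the test expects after the "Sort Waits" pass:

# Illustrative only — not real thunder trace syntax.
#     c = add(a, b)
#     handle = all_gather_prim_impl(...)     # recorded as comm_idx
#     e = add(c, 1)                          # independent work that sort_waits
#     f = mul(e, b)                          #   slides between the comm and its wait
#     wait_prim_impl(handle)                 # asserted: idx > comm_idx + 2
# i.e. each wait must trail the most recent collective by at least three bound symbols.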


common_utils.instantiate_parametrized_tests(DistributedCollectiveOpTest)

