Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sort allgathers according to consumer order, reduce scatter according to producer order #592

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 134 additions & 58 deletions thunder/distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,64 +51,6 @@ def key(node: Node) -> int:
return primal_trace


def sort_communication_ops(execution_trace):
    """
    Reorder the bound symbols of an execution trace around its communication ops.

    The stated goal is to keep wait_prim_impl nodes as far from their
    communication ops as possible, while all_gather_prim_impl nodes are kept
    next to their wait_prim_impl node to reduce the peak allocated memory.

    Args:
        execution_trace (TraceCtx): The execution trace to sort.

    Returns:
        TraceCtx: The sorted execution trace. The input trace object itself is
        returned when it contains no wait_prim_impl node.
    """
    from thunder.executors.torchex import (
        wait_prim_impl,
        reduce_scatter_prim_impl,
        all_reduce_prim_impl,
        all_gather_prim_impl,
        unpack_for_fsdp_prim_impl,
    )

    # Nothing to reorder unless at least one wait node is present.
    for bsym in execution_trace.bound_symbols:
        if bsym.sym.id == wait_prim_impl.id:
            break
    else:
        return execution_trace

    # Position of every bound symbol in the incoming trace.
    order_in_trace = {}
    for position, bsym in enumerate(execution_trace.bound_symbols):
        order_in_trace[bsym] = position
    num_positions = len(order_in_trace)

    def priority(node: Node) -> int:
        # Ranking used by the selector below (the largest value is selected):
        #   all_gather               -> num_positions + original position
        #   wait / unpack_for_fsdp   -> num_positions
        #   any other node           -> original position
        #   reduce_scatter / all_reduce -> -numel of the first argument
        sym_id = node.bsym.sym.id
        if sym_id == wait_prim_impl.id or sym_id == unpack_for_fsdp_prim_impl.id:
            return num_positions
        if sym_id == reduce_scatter_prim_impl.id or sym_id == all_reduce_prim_impl.id:
            # Larger communication ops get a smaller (more negative) value
            return -node.bsym.args[0].numel
        if sym_id == all_gather_prim_impl.id:
            return num_positions + order_in_trace[node.bsym]
        # Everything else keeps its relative position from the original trace
        return order_in_trace[node.bsym]

    def prefer_comm_over_other_over_wait_over_allgather(eligible_nodes: list[Node]) -> int:
        # Return the index of the first eligible node with the highest priority.
        priorities = [priority(candidate) for candidate in eligible_nodes]
        return priorities.index(max(priorities))

    new_execution_trace = from_trace(execution_trace)

    # TODO: This pass doesn't behave correctly if del nodes are present in the trace
    has_del_node = any(bsym.sym.name == "del" for bsym in execution_trace.bound_symbols)
    check(not has_del_node, lambda: "Cannot sort execution trace with del nodes")

    # Element 1 of the DAG tuple is what the bottom-up sort starts from.
    dag = bsym_list_to_dag(execution_trace.bound_symbols)
    new_execution_trace.bound_symbols = toposort_bsym_dag(
        dag[1],
        TOPOSORT_ORDER.BOTTOM_UP,
        selector=prefer_comm_over_other_over_wait_over_allgather,
    )
    return new_execution_trace


def sort_waits(execution_trace):
"""
Sorts the wait_prim_impl nodes in the execution trace to be as far from the
Expand Down Expand Up @@ -266,3 +208,137 @@ def limit_in_flight_allgathers(

new_execution_trace.bound_symbols = new_bsyms
return new_execution_trace


def sort_allgathers(execution_trace):
"""
Sort the all_gather_prim_impl and its wait nodes according to the consumer order,
and put all_gather_prim_impl just before wait

Args:
execution_trace (TraceCtx): The execution trace to sort.

Returns:
TraceCtx: The sorted execution trace.
"""
from thunder.executors.torchex import (
wait_prim_impl,
reduce_scatter_prim_impl,
all_reduce_prim_impl,
all_gather_prim_impl,
unpack_for_fsdp_prim_impl,
)
from thunder.core import utils

if not any(bsym.sym.id == wait_prim_impl.id for bsym in execution_trace.bound_symbols):
return execution_trace

order_in_trace = {bsym: i for i, bsym in enumerate(execution_trace.bound_symbols)}
producers, consumers = utils.producers_and_consumers(execution_trace)

def prefer_comm_over_other_over_wait_over_allgather(eligible_nodes: list[Node]) -> int:
# bottom-up topological sorting, prefer allgather and wait for topological equal nodes
def key(node: Node) -> int:
match node.bsym.sym.id:
case wait_prim_impl.id:
if producers[node.bsym.flat_proxy_args[0]].sym.id == all_gather_prim_impl.id:
return len(order_in_trace)
return order_in_trace[node.bsym]
case unpack_for_fsdp_prim_impl.id:
if producers[node.bsym.flat_proxy_args[0]].sym.id == wait_prim_impl.id:
waitop = producers[node.bsym.flat_proxy_args[0]]
if producers[waitop.flat_proxy_args[0]].sym.id == all_gather_prim_impl.id:
return len(order_in_trace)
return order_in_trace[node.bsym]
case all_gather_prim_impl.id:
return len(order_in_trace) + order_in_trace[node.bsym]
case _:
# Prefer nodes that are earlier in the trace
return order_in_trace[node.bsym]

return max(range(len(eligible_nodes)), key=lambda i: key(eligible_nodes[i]))

new_execution_trace = from_trace(execution_trace)

# TODO: This pass doesn't behave correctly if del nodes are present in the trace
check(
not any(bsym.sym.name == "del" for bsym in execution_trace.bound_symbols),
lambda: "Cannot sort execution trace with del nodes",
)
new_execution_trace.bound_symbols = toposort_bsym_dag(
bsym_list_to_dag(execution_trace.bound_symbols)[1],
TOPOSORT_ORDER.BOTTOM_UP,
selector=prefer_comm_over_other_over_wait_over_allgather,
)
return new_execution_trace


def sort_reduce_ops(execution_trace):
"""
Sort the reduce/reduce_scatter and its wait node according to the producer order,
and maximum the distance between reduce/reduce_scatter and wait

Args:
execution_trace (TraceCtx): The execution trace to sort.

Returns:
TraceCtx: The sorted execution trace.
"""
from thunder.executors.torchex import (
wait_prim_impl,
reduce_scatter_prim_impl,
all_reduce_prim_impl,
all_gather_prim_impl,
pack_for_fsdp_prim_impl,
)
from thunder.core import utils

if not any(bsym.sym.id == wait_prim_impl.id for bsym in execution_trace.bound_symbols):
return execution_trace

order_in_trace = {bsym: i for i, bsym in enumerate(execution_trace.bound_symbols)}
producers, consumers = utils.producers_and_consumers(execution_trace)

def prefer_comm_over_other_over_wait(eligible_nodes: list[Node]) -> int:
# top-down topological sorting, prefer reduce/reduce_scatter and pick wait at last for topological equal nodes
def key(node: Node) -> int:
match node.bsym.sym.id:
case wait_prim_impl.id:
if producers[node.bsym.flat_proxy_args[0]].sym.id in (
reduce_scatter_prim_impl.id,
all_reduce_prim_impl.id,
):
return len(order_in_trace)
return order_in_trace[node.bsym]
case reduce_scatter_prim_impl.id | all_reduce_prim_impl.id:
# Prefer larger communication ops over smaller ones
return -node.bsym.args[0].numel
case pack_for_fsdp_prim_impl.id:
pack_consumer = consumers.get(node.bsym.flat_proxy_outs[0], None)
check(
pack_consumer is not None and len(pack_consumer) == 1,
lambda: f"Pack operator should have one consumer",
)
# Prefer larger communication ops over smaller ones
if pack_consumer[0].sym.id in (reduce_scatter_prim_impl.id, all_reduce_prim_impl.id):
return -node.bsym.flat_proxy_outs[0].numel
return order_in_trace[node.bsym]
case _:
# Prefer nodes that are earlier in the trace
return order_in_trace[node.bsym]

return min(range(len(eligible_nodes)), key=lambda i: key(eligible_nodes[i]))

new_execution_trace = from_trace(execution_trace)

# TODO: This pass doesn't behave correctly if del nodes are present in the trace
check(
not any(bsym.sym.name == "del" for bsym in execution_trace.bound_symbols),
lambda: "Cannot sort execution trace with del nodes",
)
new_execution_trace.bound_symbols = toposort_bsym_dag(
bsym_list_to_dag(execution_trace.bound_symbols)[0],
TOPOSORT_ORDER.TOP_DOWN,
selector=prefer_comm_over_other_over_wait,
)
return new_execution_trace
11 changes: 6 additions & 5 deletions thunder/executors/torch_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
from thunder.core.rematerialization import rematerialize_all_gather, rematerialize_forward_and_backward
from thunder.core.transforms import forward_and_backward_from_trace
from thunder.distributed.transforms import FSDPCommBucketing
from thunder.distributed.utils import sort_data_parallel_syncs, sort_waits, sort_communication_ops
from thunder.distributed.utils import sort_data_parallel_syncs, sort_waits, sort_allgathers, sort_reduce_ops
from thunder.executors.passes import del_last_used, transform_for_execution

utils.check(compile_data is not None, lambda: "`compile_data` is required")
Expand Down Expand Up @@ -227,13 +227,14 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
from thunder.distributed import FSDPBucketingStrategy
from thunder.distributed.utils import limit_in_flight_allgathers

fw_extrace = sort_communication_ops(fw_extrace)
fw_extrace = sort_allgathers(fw_extrace)
fw_extrace = limit_in_flight_allgathers(
fw_extrace,
3,
compile_data.fn.bucketing_strategy != FSDPBucketingStrategy.NONE,
)
bw_extrace = sort_communication_ops(bw_extrace)
bw_extrace = sort_allgathers(bw_extrace)
bw_extrace = sort_reduce_ops(bw_extrace)
bw_extrace = limit_in_flight_allgathers(
bw_extrace,
3,
Expand All @@ -245,14 +246,14 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
from sys import maxsize as INT_MAX

# sort the allgather+wait as consumer order just before consumer
fw_extrace = sort_communication_ops(fw_extrace)
fw_extrace = sort_allgathers(fw_extrace)
# unlimited number of allgathers, i.e. allgathers are listed at the beginning of the trace in consumer order and each wait stays just before its consumer
fw_extrace = limit_in_flight_allgathers(
fw_extrace,
INT_MAX,
compile_data.fn.bucketing_strategy != FSDPBucketingStrategy.NONE,
)
bw_extrace = sort_waits(bw_extrace)
bw_extrace = sort_reduce_ops(bw_extrace)
if getattr(compile_data.fn, "use_ddp", False):
bw_extrace = sort_waits(bw_extrace)

Expand Down
Loading