From 570a30264df9ae35e97f8da269a7c64386e98f0b Mon Sep 17 00:00:00 2001 From: Yan Wang Date: Wed, 12 Jun 2024 16:09:03 +0200 Subject: [PATCH 1/4] Sort allgathers according to consumer order, reduce scatter according to producer order (#574) --- thunder/distributed/utils.py | 173 ++++++++++++++++++---------- thunder/executors/torch_autograd.py | 15 ++- 2 files changed, 125 insertions(+), 63 deletions(-) diff --git a/thunder/distributed/utils.py b/thunder/distributed/utils.py index 7e0e81abb0..ba52f381af 100644 --- a/thunder/distributed/utils.py +++ b/thunder/distributed/utils.py @@ -51,64 +51,6 @@ def key(node: Node) -> int: return primal_trace -def sort_communication_ops(execution_trace): - """ - Sorts the wait_prim_impl nodes in the execution trace to be as far from the - communication ops as possible, except for the all_gather_prim_impl nodes, the all_gather_prim_impl nodes - are sorted to be next to wait_prim_impl node to reduce the peak allocated memory - - Args: - execution_trace (TraceCtx): The execution trace to sort. - - Returns: - TraceCtx: The sorted execution trace. - """ - from thunder.executors.torchex import ( - wait_prim_impl, - reduce_scatter_prim_impl, - all_reduce_prim_impl, - all_gather_prim_impl, - unpack_for_fsdp_prim_impl, - ) - - if not any(bsym.sym.id == wait_prim_impl.id for bsym in execution_trace.bound_symbols): - return execution_trace - - order_in_trace = {bsym: i for i, bsym in enumerate(execution_trace.bound_symbols)} - - def prefer_comm_over_other_over_wait_over_allgather(eligible_nodes: list[Node]) -> int: - # Prefer communication ops other than "all_gather_prim_impl" over other nodes and prefer other - # nodes over "wait_prim_impl", pick "all_gather_prim_impl" last. - def key(node: Node) -> int: - match node.bsym.sym.id: - case wait_prim_impl.id | unpack_for_fsdp_prim_impl.id: - return len(order_in_trace) - case reduce_scatter_prim_impl.id | all_reduce_prim_impl.id: - # Prefer larger communication ops over smaller ones - return -node.bsym.args[0].numel - case all_gather_prim_impl.id: - return len(order_in_trace) + order_in_trace[node.bsym] - case _: - # Prefer nodes that are earlier in the trace - return order_in_trace[node.bsym] - - return max(range(len(eligible_nodes)), key=lambda i: key(eligible_nodes[i])) - - new_execution_trace = from_trace(execution_trace) - - # TODO: This pass doesn't behave correctly if del nodes are present in the trace - check( - not any(bsym.sym.name == "del" for bsym in execution_trace.bound_symbols), - lambda: "Cannot sort execution trace with del nodes", - ) - new_execution_trace.bound_symbols = toposort_bsym_dag( - bsym_list_to_dag(execution_trace.bound_symbols)[1], - TOPOSORT_ORDER.BOTTOM_UP, - selector=prefer_comm_over_other_over_wait_over_allgather, - ) - return new_execution_trace - - def sort_waits(execution_trace): """ Sorts the wait_prim_impl nodes in the execution trace to be as far from the @@ -266,3 +208,118 @@ def limit_in_flight_allgathers( new_execution_trace.bound_symbols = new_bsyms return new_execution_trace + + +def sort_allgathers(execution_trace): + """ + Sort the all_gather_prim_impl and its wait nodes according to the consumer order, + and put all_gather_prim_impl just before wait + + Args: + execution_trace (TraceCtx): The execution trace to sort. + + Returns: + TraceCtx: The sorted execution trace. + """ + from thunder.executors.torchex import ( + wait_prim_impl, + reduce_scatter_prim_impl, + all_reduce_prim_impl, + all_gather_prim_impl, + unpack_for_fsdp_prim_impl, + ) + from thunder.core import utils + + if not any(bsym.sym.id == wait_prim_impl.id for bsym in execution_trace.bound_symbols): + return execution_trace + + order_in_trace = {bsym: i for i, bsym in enumerate(execution_trace.bound_symbols)} + producers, consumers = utils.producers_and_consumers(execution_trace) + + def prefer_comm_over_other_over_wait_over_allgather(eligible_nodes: list[Node]) -> int: + # bottom-up topological sorting, prefer allgather and wait for topological equal nodes + def key(node: Node) -> int: + match node.bsym.sym.id: + case wait_prim_impl.id: + if producers[node.bsym.flat_proxy_args[0]].sym.id == all_gather_prim_impl.id: + return len(order_in_trace) + return order_in_trace[node.bsym] + case all_gather_prim_impl.id: + # return -node.bsym.args[0].numel # chose smaller first, max distance + return len(order_in_trace) + order_in_trace[node.bsym] # use the consumer order, min distance + case _: + # Prefer nodes that are earlier in the trace + return order_in_trace[node.bsym] + + return max(range(len(eligible_nodes)), key=lambda i: key(eligible_nodes[i])) + + new_execution_trace = from_trace(execution_trace) + + # TODO: This pass doesn't behave correctly if del nodes are present in the trace + check( + not any(bsym.sym.name == "del" for bsym in execution_trace.bound_symbols), + lambda: "Cannot sort execution trace with del nodes", + ) + new_execution_trace.bound_symbols = toposort_bsym_dag( + bsym_list_to_dag(execution_trace.bound_symbols)[1], + TOPOSORT_ORDER.BOTTOM_UP, + selector=prefer_comm_over_other_over_wait_over_allgather, + ) + return new_execution_trace + + +def sort_reduce_ops(execution_trace): + """ + Sort the reduce/reduce_scatter and its wait node according to the producer order, + and maximum the distance between reduce/reduce_scatter and wait + + Args: + execution_trace (TraceCtx): The execution trace to sort. + + Returns: + TraceCtx: The sorted execution trace. + """ + from thunder.executors.torchex import ( + wait_prim_impl, + reduce_scatter_prim_impl, + all_reduce_prim_impl, + all_gather_prim_impl, + ) + from thunder.core import utils + + if not any(bsym.sym.id == wait_prim_impl.id for bsym in execution_trace.bound_symbols): + return execution_trace + + order_in_trace = {bsym: i for i, bsym in enumerate(execution_trace.bound_symbols)} + producers, consumers = utils.producers_and_consumers(execution_trace) + + def prefer_comm_over_other_over_wait(eligible_nodes: list[Node]) -> int: + # top-down topological sorting, prefer reduce/reduce_scatter and pick wait at last for topological equal nodes + def key(node: Node) -> int: + match node.bsym.sym.id: + case wait_prim_impl.id: + if producers[node.bsym.flat_proxy_args[0]].sym.id in (reduce_scatter_prim_impl.id, all_reduce_prim_impl.id): # all_gather_prim_impl.id: + return len(order_in_trace) + return order_in_trace[node.bsym] + case reduce_scatter_prim_impl.id | all_reduce_prim_impl.id: + # Prefer larger communication ops over smaller ones + return -node.bsym.args[0].numel + case _: + # Prefer nodes that are earlier in the trace + return order_in_trace[node.bsym] + + return min(range(len(eligible_nodes)), key=lambda i: key(eligible_nodes[i])) + + new_execution_trace = from_trace(execution_trace) + + # TODO: This pass doesn't behave correctly if del nodes are present in the trace + check( + not any(bsym.sym.name == "del" for bsym in execution_trace.bound_symbols), + lambda: "Cannot sort execution trace with del nodes", + ) + new_execution_trace.bound_symbols = toposort_bsym_dag( + bsym_list_to_dag(execution_trace.bound_symbols)[0], + TOPOSORT_ORDER.TOP_DOWN, + selector=prefer_comm_over_other_over_wait, + ) + return new_execution_trace \ No newline at end of file diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index 820359f4d4..b78e94651e 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -110,7 +110,7 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat from thunder.core.rematerialization import rematerialize_all_gather, rematerialize_forward_and_backward from thunder.core.transforms import forward_and_backward_from_trace from thunder.distributed.transforms import FSDPCommBucketing - from thunder.distributed.utils import sort_data_parallel_syncs, sort_waits, sort_communication_ops + from thunder.distributed.utils import sort_data_parallel_syncs, sort_waits, sort_allgathers, sort_reduce_ops from thunder.executors.passes import del_last_used, transform_for_execution utils.check(compile_data is not None, lambda: "`compile_data` is required") @@ -223,13 +223,16 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat from thunder.distributed import FSDPBucketingStrategy from thunder.distributed.utils import limit_in_flight_allgathers - fw_extrace = sort_communication_ops(fw_extrace) + # fw_extrace = sort_communication_ops(fw_extrace) + fw_extrace = sort_allgathers(fw_extrace) fw_extrace = limit_in_flight_allgathers( fw_extrace, 3, compile_data.fn.bucketing_strategy != FSDPBucketingStrategy.NONE, ) - bw_extrace = sort_communication_ops(bw_extrace) + # bw_extrace = sort_communication_ops(bw_extrace) + bw_extrace = sort_reduce_ops(bw_extrace) + bw_extrace = sort_allgathers(bw_extrace) bw_extrace = limit_in_flight_allgathers( bw_extrace, 3, @@ -241,14 +244,16 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat from sys import maxsize as INT_MAX # sort the allgather+wait as consumer order just before consumer - fw_extrace = sort_communication_ops(fw_extrace) + # fw_extrace = sort_communication_ops(fw_extrace) + fw_extrace = sort_allgathers(fw_extrace) # unlimited number of allgathers, i.e. allgathers are listed at the beginning of the trace in consumer order and wait stays just before wait fw_extrace = limit_in_flight_allgathers( fw_extrace, INT_MAX, compile_data.fn.bucketing_strategy != FSDPBucketingStrategy.NONE, ) - bw_extrace = sort_waits(bw_extrace) + # bw_extrace = sort_waits(bw_extrace) + bw_extrace = sort_reduce_ops(bw_extrace) if getattr(compile_data.fn, "use_ddp", False): bw_extrace = sort_waits(bw_extrace) From d3497a674de84f81e8d62f0bf7fb28ec5d9741df Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 13 Jun 2024 13:41:26 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/distributed/utils.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/thunder/distributed/utils.py b/thunder/distributed/utils.py index ba52f381af..1c515351e6 100644 --- a/thunder/distributed/utils.py +++ b/thunder/distributed/utils.py @@ -237,7 +237,7 @@ def sort_allgathers(execution_trace): producers, consumers = utils.producers_and_consumers(execution_trace) def prefer_comm_over_other_over_wait_over_allgather(eligible_nodes: list[Node]) -> int: - # bottom-up topological sorting, prefer allgather and wait for topological equal nodes + # bottom-up topological sorting, prefer allgather and wait for topological equal nodes def key(node: Node) -> int: match node.bsym.sym.id: case wait_prim_impl.id: @@ -294,11 +294,14 @@ def sort_reduce_ops(execution_trace): producers, consumers = utils.producers_and_consumers(execution_trace) def prefer_comm_over_other_over_wait(eligible_nodes: list[Node]) -> int: - # top-down topological sorting, prefer reduce/reduce_scatter and pick wait at last for topological equal nodes + # top-down topological sorting, prefer reduce/reduce_scatter and pick wait at last for topological equal nodes def key(node: Node) -> int: match node.bsym.sym.id: case wait_prim_impl.id: - if producers[node.bsym.flat_proxy_args[0]].sym.id in (reduce_scatter_prim_impl.id, all_reduce_prim_impl.id): # all_gather_prim_impl.id: + if producers[node.bsym.flat_proxy_args[0]].sym.id in ( + reduce_scatter_prim_impl.id, + all_reduce_prim_impl.id, + ): # all_gather_prim_impl.id: return len(order_in_trace) return order_in_trace[node.bsym] case reduce_scatter_prim_impl.id | all_reduce_prim_impl.id: @@ -322,4 +325,4 @@ def key(node: Node) -> int: TOPOSORT_ORDER.TOP_DOWN, selector=prefer_comm_over_other_over_wait, ) - return new_execution_trace \ No newline at end of file + return new_execution_trace From 313596d0627d47e71557a5db8a981763f312dab3 Mon Sep 17 00:00:00 2001 From: Yan Wang Date: Tue, 25 Jun 2024 05:23:22 -0700 Subject: [PATCH 3/4] rm comments --- thunder/distributed/utils.py | 5 ++--- thunder/executors/torch_autograd.py | 4 ---- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/thunder/distributed/utils.py b/thunder/distributed/utils.py index 1c515351e6..fc1e0c56bf 100644 --- a/thunder/distributed/utils.py +++ b/thunder/distributed/utils.py @@ -245,8 +245,7 @@ def key(node: Node) -> int: return len(order_in_trace) return order_in_trace[node.bsym] case all_gather_prim_impl.id: - # return -node.bsym.args[0].numel # chose smaller first, max distance - return len(order_in_trace) + order_in_trace[node.bsym] # use the consumer order, min distance + return len(order_in_trace) + order_in_trace[node.bsym] case _: # Prefer nodes that are earlier in the trace return order_in_trace[node.bsym] @@ -301,7 +300,7 @@ def key(node: Node) -> int: if producers[node.bsym.flat_proxy_args[0]].sym.id in ( reduce_scatter_prim_impl.id, all_reduce_prim_impl.id, - ): # all_gather_prim_impl.id: + ): return len(order_in_trace) return order_in_trace[node.bsym] case reduce_scatter_prim_impl.id | all_reduce_prim_impl.id: diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index b0830a75eb..36106dfa9e 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -227,14 +227,12 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat from thunder.distributed import FSDPBucketingStrategy from thunder.distributed.utils import limit_in_flight_allgathers - # fw_extrace = sort_communication_ops(fw_extrace) fw_extrace = sort_allgathers(fw_extrace) fw_extrace = limit_in_flight_allgathers( fw_extrace, 3, compile_data.fn.bucketing_strategy != FSDPBucketingStrategy.NONE, ) - # bw_extrace = sort_communication_ops(bw_extrace) bw_extrace = sort_reduce_ops(bw_extrace) bw_extrace = sort_allgathers(bw_extrace) bw_extrace = limit_in_flight_allgathers( @@ -248,7 +246,6 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat from sys import maxsize as INT_MAX # sort the allgather+wait as consumer order just before consumer - # fw_extrace = sort_communication_ops(fw_extrace) fw_extrace = sort_allgathers(fw_extrace) # unlimited number of allgathers, i.e. allgathers are listed at the beginning of the trace in consumer order and wait stays just before wait fw_extrace = limit_in_flight_allgathers( @@ -256,7 +253,6 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat INT_MAX, compile_data.fn.bucketing_strategy != FSDPBucketingStrategy.NONE, ) - # bw_extrace = sort_waits(bw_extrace) bw_extrace = sort_reduce_ops(bw_extrace) if getattr(compile_data.fn, "use_ddp", False): bw_extrace = sort_waits(bw_extrace) From 8e87c4e8bc141af49212572d899e8d6fafea1cae Mon Sep 17 00:00:00 2001 From: Yan Wang Date: Tue, 25 Jun 2024 06:55:29 -0700 Subject: [PATCH 4/4] fix for bucketing --- thunder/distributed/utils.py | 17 +++++++++++++++++ thunder/executors/torch_autograd.py | 2 +- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/thunder/distributed/utils.py b/thunder/distributed/utils.py index fc1e0c56bf..52c90eabe9 100644 --- a/thunder/distributed/utils.py +++ b/thunder/distributed/utils.py @@ -244,6 +244,12 @@ def key(node: Node) -> int: if producers[node.bsym.flat_proxy_args[0]].sym.id == all_gather_prim_impl.id: return len(order_in_trace) return order_in_trace[node.bsym] + case unpack_for_fsdp_prim_impl.id: + if producers[node.bsym.flat_proxy_args[0]].sym.id == wait_prim_impl.id: + waitop = producers[node.bsym.flat_proxy_args[0]] + if producers[waitop.flat_proxy_args[0]].sym.id == all_gather_prim_impl.id: + return len(order_in_trace) + return order_in_trace[node.bsym] case all_gather_prim_impl.id: return len(order_in_trace) + order_in_trace[node.bsym] case _: @@ -283,6 +289,7 @@ def sort_reduce_ops(execution_trace): reduce_scatter_prim_impl, all_reduce_prim_impl, all_gather_prim_impl, + pack_for_fsdp_prim_impl, ) from thunder.core import utils @@ -306,6 +313,16 @@ def key(node: Node) -> int: case reduce_scatter_prim_impl.id | all_reduce_prim_impl.id: # Prefer larger communication ops over smaller ones return -node.bsym.args[0].numel + case pack_for_fsdp_prim_impl.id: + pack_consumer = consumers.get(node.bsym.flat_proxy_outs[0], None) + check( + pack_consumer is not None and len(pack_consumer) == 1, + lambda: f"Pack operator should have one consumer", + ) + # Prefer larger communication ops over smaller ones + if pack_consumer[0].sym.id in (reduce_scatter_prim_impl.id, all_reduce_prim_impl.id): + return -node.bsym.flat_proxy_outs[0].numel + return order_in_trace[node.bsym] case _: # Prefer nodes that are earlier in the trace return order_in_trace[node.bsym] diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index 36106dfa9e..7df89396af 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -233,8 +233,8 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat 3, compile_data.fn.bucketing_strategy != FSDPBucketingStrategy.NONE, ) - bw_extrace = sort_reduce_ops(bw_extrace) bw_extrace = sort_allgathers(bw_extrace) + bw_extrace = sort_reduce_ops(bw_extrace) bw_extrace = limit_in_flight_allgathers( bw_extrace, 3,