diff --git a/autoparallel/graph_passes/split_di_dw_graph.py b/autoparallel/graph_passes/split_di_dw_graph.py index ed6d6b9..a60f3ef 100644 --- a/autoparallel/graph_passes/split_di_dw_graph.py +++ b/autoparallel/graph_passes/split_di_dw_graph.py @@ -4,25 +4,18 @@ # LICENSE file in the root directory of this source tree. import copy -import itertools import operator +from dataclasses import dataclass -import sympy import torch import torch.fx as fx from torch._functorch.partitioners import ( - SavedForBackwardsAOTOutput, + PartitionedGraphSignature, + PartitionedGraphSignatureBuilder, _extract_fwd_bwd_outputs, _extract_graph_with_inputs_outputs, - _is_backward_state, - _is_bwd_seed_offset, - _is_fwd_seed_offset, - _is_primal, - _remove_by_name, - find_symbol_binding_fx_nodes, - free_symbols, + _extract_graphs_from_partition_inputs, is_sym_node, - is_symbol_binding_fx_node, ) from torch.utils._ordered_set import OrderedSet @@ -64,130 +57,51 @@ def reorder_output_grads(bw_gm, num_weight_gradients): return len(grad_inputs) -# This is a copy of the function used by the default partitioner, -# which does *not* reorder symint activations. -# This is reordering is needed by the custom autograd.Function in AOTDispatcher, -# but isn't needed in our dI/dW splitting since there is no autograd in the loop. -# TODO: provide a way to gt this behavior automatically out of the default partitioner -def _extract_fwd_bwd_modules( +@dataclass +class DiDwPartitionSignature(PartitionedGraphSignature): + @classmethod + def _bwd_graph_inputs_preliminary( + cls, builder: PartitionedGraphSignatureBuilder + ) -> list[fx.Node]: + return ( + builder._saved_values + + builder._saved_sym_nodes + + builder._bwd_seed_offset_inputs + ) + + def bwd_graph_inputs(self) -> list[fx.Node]: + return self.saved_values + self.saved_sym_nodes + self.bwd_seed_offset_inputs + + +def _extract_fwd_bwd_modules_didw( joint_module: fx.GraphModule, saved_values: list[fx.Node], saved_sym_nodes: list[fx.Node], *, num_fwd_outputs: int, ) -> tuple[fx.GraphModule, fx.GraphModule]: - ( - fwd_outputs, - bwd_outputs, - fwd_outputs_descs, - bwd_outputs_descs, - ) = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs) - placeholders = joint_module.graph.find_nodes(op="placeholder") - primal_inputs = [*filter(_is_primal, placeholders)] - fwd_seed_offset_inputs = [*filter(_is_fwd_seed_offset, placeholders)] - bwd_seed_offset_inputs = [*filter(_is_bwd_seed_offset, placeholders)] - backward_state_inputs = [*filter(_is_backward_state, placeholders)] - - bwd_graph = _extract_graph_with_inputs_outputs( - joint_module.graph, - saved_values + saved_sym_nodes + bwd_seed_offset_inputs, - bwd_outputs, - bwd_outputs_descs, - "backward", - ignore_must_be_in_fw_bw=True, - ) - - distributed_enabled = torch.distributed.is_available() - - for node in bwd_graph.find_nodes(op="placeholder"): - # This is to filter out saved values that don't actually end up being used by the backwards pass - if not node.users: - _remove_by_name(saved_values, node.name) - _remove_by_name(saved_sym_nodes, node.name) - # wait_tensor is a bit special: if we have a "dead activation" that is not used in the bw, - # but this dead activation is actually a collective, - # then the collective will generally by followed by a wait_tensor() call. - # we need to peak one node further to see if this wait_tensor is dead as well. 
-        elif distributed_enabled and all(
-            n.target is torch.ops._c10d_functional.wait_tensor.default
-            and len(n.users) == 0
-            for n in node.users
-        ):
-            _remove_by_name(saved_values, node.name)
-            _remove_by_name(saved_sym_nodes, node.name)
-        elif _is_backward_state(node):
-            # BackwardState is saved directly
-            _remove_by_name(saved_values, node.name)
-            assert backward_state_inputs
-
-    # Now that we have the finalized list of saved values, we need to ensure
-    # we propagate all symbols which are referenced by backwards inputs.
-    # These are not directly used in the graph but are required for downstream
-    # sizevar assignment
-    saved_symbols: OrderedSet[sympy.Symbol] = OrderedSet()
-    saved_sym_nodes_binding = []
-    saved_sym_nodes_derived = []
-
-    # Some symbols may already be bound in the directly saved_sym_nodes,
-    # keep track of them so we don't re-bind them
-    for node in saved_sym_nodes:
-        symbol = is_symbol_binding_fx_node(node)
-        if symbol:
-            saved_symbols.add(symbol)
-            saved_sym_nodes_binding.append(node)
-        else:
-            saved_sym_nodes_derived.append(node)
-
-    # Now go through all of the prospective backward inputs and track any
-    # other symbols we need to bind
-    symbol_bindings = find_symbol_binding_fx_nodes(joint_module.graph)
-    for node in itertools.chain(saved_sym_nodes_derived, saved_values):
-        if "val" not in node.meta:
-            continue
-        new_symbols = free_symbols(node.meta["val"]) - saved_symbols
-        # NB: Deterministic order please!
-        for s in sorted(new_symbols, key=lambda s: s.name):
-            # NB: For well formed graphs, the symbol should always be present,
-            # but we also have ways to produce ill-formed graphs, e.g., direct
-            # make_fx usages, so don't choke in this case
-            if s not in symbol_bindings:
-                continue
-            saved_sym_nodes_binding.append(symbol_bindings[s])
-        saved_symbols |= new_symbols
-
-    # Update saved_sym_nodes that are now reordered to have all bindings at
-    # front. This can also be used later on to figure out the position of saved
-    # sym nodes in the output of fwd graph.
-    saved_sym_nodes.clear()
-    saved_sym_nodes.extend(saved_sym_nodes_binding + saved_sym_nodes_derived)
-
-    # Now, we re-generate the fwd/bwd graphs.
-    # NB: This might increase compilation time, but I doubt it matters
-    fwd_graph = _extract_graph_with_inputs_outputs(
-        joint_module.graph,
-        primal_inputs + fwd_seed_offset_inputs,
-        fwd_outputs + saved_values + saved_sym_nodes,
-        fwd_outputs_descs
-        + [
-            SavedForBackwardsAOTOutput(i)
-            for i in range(len(saved_values) + len(saved_sym_nodes))
-        ],
-        "forward",
-        ignore_must_be_in_fw_bw=True,
+    """
+    Extract forward and backward modules for dI/dW splitting.
+
+    This uses PartitionedGraphSignatureBuilder (with DiDwPartitionSignature as
+    the signature dataclass) to construct the partition signature, and
+    _extract_graphs_from_partition_inputs to extract the graphs.
+    DiDwPartitionSignature keeps the backward inputs as saved_values +
+    saved_sym_nodes + bwd_seed_offset_inputs; the symint-activation reordering
+    the default partitioner does for the custom autograd.Function in
+    AOTDispatcher is not needed here, since there is no autograd in the loop.
+ """ + builder = PartitionedGraphSignatureBuilder( + joint_module, saved_values, saved_sym_nodes, num_fwd_outputs ) - bwd_graph = _extract_graph_with_inputs_outputs( - joint_module.graph, - saved_values + saved_sym_nodes + bwd_seed_offset_inputs + backward_state_inputs, - bwd_outputs, - bwd_outputs_descs, - "backward", - ignore_must_be_in_fw_bw=True, + builder.override_dataclass(DiDwPartitionSignature) + builder.filter_unused_bwd_inputs(ignore_must_be_in_fw_bw=True) + builder.resolve_symbol_bindings() + builder.separate_and_reorder_saved_values() + assert not builder._tangent_inputs + assert not builder._saved_opaque_objects + assert not builder._no_vc_check_start_idx + signature = builder.build() + + return _extract_graphs_from_partition_inputs( + joint_module, signature, ignore_must_be_in_fw_bw=True ) - fwd_module = fx._lazy_graph_module._make_graph_module(joint_module, fwd_graph) - bwd_module = fx._lazy_graph_module._make_graph_module(joint_module, bwd_graph) - return fwd_module, bwd_module - # TODO: in theory we can infer num_weight_gradients from the graph metadata directly def split_di_dw_graph( @@ -207,9 +121,6 @@ def split_di_dw_graph( args = list(bw_gm.graph.find_nodes(op="placeholder")) - # bw_inputs, bw_weights = default_partition(bw_gm, args, num_fwd_outputs=num_input_gradients) - # return bw_inputs, bw_weights, num_input_gradients - ( grad_inps, grad_weights, @@ -257,7 +168,7 @@ def split_di_dw_graph( saved_values.append(node) saved_values = list(dict.fromkeys(saved_values).keys()) saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys()) - bw_inputs, bw_weights = _extract_fwd_bwd_modules( + bw_inputs, bw_weights = _extract_fwd_bwd_modules_didw( bw_gm, saved_values, saved_sym_nodes=saved_sym_nodes,
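For reference, a minimal sketch of how the split graphs might be consumed, e.g. by a pipeline schedule that wants input gradients early and weight gradients later. It assumes split_di_dw_graph(bw_gm, num_weight_gradients) returns (dI module, dW module, number of input gradients), matching the commented-out "return bw_inputs, bw_weights, num_input_gradients" removed above, and that the dW module's placeholders are exactly the dI module's extra outputs (saved_values followed by saved_sym_nodes, per DiDwPartitionSignature.bwd_graph_inputs, with no seed/offset inputs). The actual call sites in autoparallel may differ.

import torch.fx as fx

from autoparallel.graph_passes.split_di_dw_graph import split_di_dw_graph


def run_split_backward(bw_gm: fx.GraphModule, bw_args: list, num_weight_gradients: int):
    # Assumed return convention: (dI graph, dW graph, number of input gradients).
    di_gm, dw_gm, num_input_gradients = split_di_dw_graph(bw_gm, num_weight_gradients)

    # The dI graph takes the original backward-graph inputs and returns the
    # input gradients followed by the values saved for the dW graph.
    di_out = di_gm(*bw_args)
    grad_inputs = di_out[:num_input_gradients]
    saved_for_dw = di_out[num_input_gradients:]

    # The dW graph consumes only those saved values (saved_values, then
    # saved_sym_nodes, in the order defined by DiDwPartitionSignature), so the
    # weight-gradient computation can be deferred until it is actually needed.
    def compute_weight_grads():
        return dw_gm(*saved_for_dw)

    return grad_inputs, compute_weight_grads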