171 changes: 41 additions & 130 deletions autoparallel/graph_passes/split_di_dw_graph.py
@@ -4,25 +4,18 @@
# LICENSE file in the root directory of this source tree.

import copy
import itertools
import operator
from dataclasses import dataclass

import sympy
import torch
import torch.fx as fx
from torch._functorch.partitioners import (
SavedForBackwardsAOTOutput,
PartitionedGraphSignature,
PartitionedGraphSignatureBuilder,
_extract_fwd_bwd_outputs,
_extract_graph_with_inputs_outputs,
_is_backward_state,
_is_bwd_seed_offset,
_is_fwd_seed_offset,
_is_primal,
_remove_by_name,
find_symbol_binding_fx_nodes,
free_symbols,
_extract_graphs_from_partition_inputs,
is_sym_node,
is_symbol_binding_fx_node,
)
from torch.utils._ordered_set import OrderedSet

@@ -64,130 +57,51 @@ def reorder_output_grads(bw_gm, num_weight_gradients):
return len(grad_inputs)


# This is a copy of the function used by the default partitioner,
# which does *not* reorder symint activations.
# This reordering is needed by the custom autograd.Function in AOTDispatcher,
# but isn't needed in our dI/dW splitting since there is no autograd in the loop.
# TODO: provide a way to get this behavior automatically out of the default partitioner
def _extract_fwd_bwd_modules(
@dataclass
class DiDwPartitionSignature(PartitionedGraphSignature):
@classmethod
def _bwd_graph_inputs_preliminary(
cls, builder: PartitionedGraphSignatureBuilder
) -> list[fx.Node]:
return (
builder._saved_values
+ builder._saved_sym_nodes
+ builder._bwd_seed_offset_inputs
)

def bwd_graph_inputs(self) -> list[fx.Node]:
return self.saved_values + self.saved_sym_nodes + self.bwd_seed_offset_inputs


def _extract_fwd_bwd_modules_didw(
joint_module: fx.GraphModule,
saved_values: list[fx.Node],
saved_sym_nodes: list[fx.Node],
*,
num_fwd_outputs: int,
) -> tuple[fx.GraphModule, fx.GraphModule]:
(
fwd_outputs,
bwd_outputs,
fwd_outputs_descs,
bwd_outputs_descs,
) = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
placeholders = joint_module.graph.find_nodes(op="placeholder")
primal_inputs = [*filter(_is_primal, placeholders)]
fwd_seed_offset_inputs = [*filter(_is_fwd_seed_offset, placeholders)]
bwd_seed_offset_inputs = [*filter(_is_bwd_seed_offset, placeholders)]
backward_state_inputs = [*filter(_is_backward_state, placeholders)]

bwd_graph = _extract_graph_with_inputs_outputs(
joint_module.graph,
saved_values + saved_sym_nodes + bwd_seed_offset_inputs,
bwd_outputs,
bwd_outputs_descs,
"backward",
ignore_must_be_in_fw_bw=True,
)

distributed_enabled = torch.distributed.is_available()

for node in bwd_graph.find_nodes(op="placeholder"):
# This is to filter out saved values that don't actually end up being used by the backwards pass
if not node.users:
_remove_by_name(saved_values, node.name)
_remove_by_name(saved_sym_nodes, node.name)
# wait_tensor is a bit special: if we have a "dead activation" that is not used in the bw,
# but this dead activation is actually a collective,
# then the collective will generally be followed by a wait_tensor() call.
# we need to peek one node further to see if this wait_tensor is dead as well
# (see the standalone sketch after this function).
elif distributed_enabled and all(
n.target is torch.ops._c10d_functional.wait_tensor.default
and len(n.users) == 0
for n in node.users
):
_remove_by_name(saved_values, node.name)
_remove_by_name(saved_sym_nodes, node.name)
elif _is_backward_state(node):
# BackwardState is saved directly
_remove_by_name(saved_values, node.name)
assert backward_state_inputs

# Now that we have the finalized list of saved values, we need to ensure
# we propagate all symbols which are referenced by backwards inputs.
# These are not directly used in the graph but are required for downstream
# sizevar assignment
saved_symbols: OrderedSet[sympy.Symbol] = OrderedSet()
saved_sym_nodes_binding = []
saved_sym_nodes_derived = []

# Some symbols may already be bound in the directly saved_sym_nodes,
# keep track of them so we don't re-bind them
for node in saved_sym_nodes:
symbol = is_symbol_binding_fx_node(node)
if symbol:
saved_symbols.add(symbol)
saved_sym_nodes_binding.append(node)
else:
saved_sym_nodes_derived.append(node)

# Now go through all of the prospective backward inputs and track any
# other symbols we need to bind
symbol_bindings = find_symbol_binding_fx_nodes(joint_module.graph)
for node in itertools.chain(saved_sym_nodes_derived, saved_values):
if "val" not in node.meta:
continue
new_symbols = free_symbols(node.meta["val"]) - saved_symbols
# NB: Deterministic order please!
for s in sorted(new_symbols, key=lambda s: s.name):
# NB: For well formed graphs, the symbol should always be present,
# but we also have ways to produce ill-formed graphs, e.g., direct
# make_fx usages, so don't choke in this case
if s not in symbol_bindings:
continue
saved_sym_nodes_binding.append(symbol_bindings[s])
saved_symbols |= new_symbols

# Update saved_sym_nodes that are now reordered to have all bindings at
# front. This can also be used later on to figure out the position of saved
# sym nodes in the output of fwd graph.
saved_sym_nodes.clear()
saved_sym_nodes.extend(saved_sym_nodes_binding + saved_sym_nodes_derived)

# Now, we re-generate the fwd/bwd graphs.
# NB: This might increase compilation time, but I doubt it matters
fwd_graph = _extract_graph_with_inputs_outputs(
joint_module.graph,
primal_inputs + fwd_seed_offset_inputs,
fwd_outputs + saved_values + saved_sym_nodes,
fwd_outputs_descs
+ [
SavedForBackwardsAOTOutput(i)
for i in range(len(saved_values) + len(saved_sym_nodes))
],
"forward",
ignore_must_be_in_fw_bw=True,
"""
Extract forward and backward modules for dI/dW splitting.

This uses PartitionedGraphSignatureBuilder with the DiDwPartitionSignature
dataclass to construct the signature, and _extract_graphs_from_partition_inputs
to extract the forward and backward graphs.
"""
builder = PartitionedGraphSignatureBuilder(
joint_module, saved_values, saved_sym_nodes, num_fwd_outputs
)
bwd_graph = _extract_graph_with_inputs_outputs(
joint_module.graph,
saved_values + saved_sym_nodes + bwd_seed_offset_inputs + backward_state_inputs,
bwd_outputs,
bwd_outputs_descs,
"backward",
ignore_must_be_in_fw_bw=True,
builder.override_dataclass(DiDwPartitionSignature)
builder.filter_unused_bwd_inputs(ignore_must_be_in_fw_bw=True)
builder.resolve_symbol_bindings()
builder.separate_and_reorder_saved_values()
assert not builder._tangent_inputs
assert not builder._saved_opaque_objects
assert not builder._no_vc_check_start_idx
signature = builder.build()

return _extract_graphs_from_partition_inputs(
joint_module, signature, ignore_must_be_in_fw_bw=True
)

fwd_module = fx._lazy_graph_module._make_graph_module(joint_module, fwd_graph)
bwd_module = fx._lazy_graph_module._make_graph_module(joint_module, bwd_graph)
return fwd_module, bwd_module
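
For reference, here is a minimal standalone sketch of the dead-activation pruning that the removed inline loop above performs and that builder.filter_unused_bwd_inputs() presumably now encapsulates. The helper name _prune_dead_saved_values is hypothetical; the body restates the logic visible in the old code.

import torch
import torch.fx as fx
from torch._functorch.partitioners import _remove_by_name


def _prune_dead_saved_values(
    bwd_graph: fx.Graph,
    saved_values: list[fx.Node],
    saved_sym_nodes: list[fx.Node],
) -> None:
    distributed_enabled = torch.distributed.is_available()
    for node in bwd_graph.find_nodes(op="placeholder"):
        if not node.users:
            # Saved value is never read by the backward graph: drop it.
            _remove_by_name(saved_values, node.name)
            _remove_by_name(saved_sym_nodes, node.name)
        elif distributed_enabled and all(
            n.target is torch.ops._c10d_functional.wait_tensor.default
            and len(n.users) == 0
            for n in node.users
        ):
            # A dead collective: its only users are wait_tensor() calls that are
            # themselves unused, so the saved value can be dropped as well.
            _remove_by_name(saved_values, node.name)
            _remove_by_name(saved_sym_nodes, node.name)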


# TODO: in theory we can infer num_weight_gradients from the graph metadata directly
def split_di_dw_graph(
@@ -207,9 +121,6 @@ def split_di_dw_graph(

args = list(bw_gm.graph.find_nodes(op="placeholder"))

# bw_inputs, bw_weights = default_partition(bw_gm, args, num_fwd_outputs=num_input_gradients)
# return bw_inputs, bw_weights, num_input_gradients

(
grad_inps,
grad_weights,
@@ -257,7 +168,7 @@ def split_di_dw_graph(
saved_values.append(node)
saved_values = list(dict.fromkeys(saved_values).keys())
saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys())
bw_inputs, bw_weights = _extract_fwd_bwd_modules(
bw_inputs, bw_weights = _extract_fwd_bwd_modules_didw(
bw_gm,
saved_values,
saved_sym_nodes=saved_sym_nodes,
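
Similarly, a minimal standalone sketch of the symbol-binding propagation that the removed code performs inline and that builder.resolve_symbol_bindings() is presumably meant to cover. The helper name _propagate_saved_symbols is hypothetical; the logic mirrors the old code in this diff.

import itertools

import sympy
import torch.fx as fx
from torch._functorch.partitioners import (
    find_symbol_binding_fx_nodes,
    free_symbols,
    is_symbol_binding_fx_node,
)
from torch.utils._ordered_set import OrderedSet


def _propagate_saved_symbols(
    joint_graph: fx.Graph,
    saved_values: list[fx.Node],
    saved_sym_nodes: list[fx.Node],
) -> None:
    saved_symbols: OrderedSet[sympy.Symbol] = OrderedSet()
    binding: list[fx.Node] = []
    derived: list[fx.Node] = []
    # Symbols already bound by directly saved sym nodes don't need re-binding.
    for node in saved_sym_nodes:
        symbol = is_symbol_binding_fx_node(node)
        if symbol:
            saved_symbols.add(symbol)
            binding.append(node)
        else:
            derived.append(node)
    # Walk the prospective backward inputs and bind any remaining free symbols.
    symbol_bindings = find_symbol_binding_fx_nodes(joint_graph)
    for node in itertools.chain(derived, saved_values):
        if "val" not in node.meta:
            continue
        new_symbols = free_symbols(node.meta["val"]) - saved_symbols
        for s in sorted(new_symbols, key=lambda s: s.name):  # deterministic order
            if s in symbol_bindings:
                binding.append(symbol_bindings[s])
        saved_symbols |= new_symbols
    # Put binding nodes first so downstream code can locate them in the
    # forward graph's outputs.
    saved_sym_nodes[:] = binding + derived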