227 changes: 221 additions & 6 deletions autoparallel/api.py
@@ -48,6 +48,203 @@
_APPLY_VIEW_MM_VIEW_PATTERN = False


def normalize_placeholder_name(name: str) -> str:
"""
Normalize a placeholder name to match the format of dynamo_flat_name_to_original_fqn keys.

Placeholder names come from: re.sub(r"[^a-zA-Z0-9]+", "_", source.name)
dynamo_flat_name_to_original_fqn keys come from: OutputGraph.module_key_name(source.name)

Both start from the same source name (e.g., L['self']._modules['linear']._parameters['weight'])
but apply different transformations:
- Placeholder: L_self___modules_linear___parameters_weight_
- module_key_name: self_linear_weight

We normalize the placeholder name to match the module_key_name format by:
1. Removing _modules_, _parameters_, _buffers_ patterns
2. Collapsing consecutive underscores
3. Removing the guard prefix (l_self_, L_self_, etc.)
4. Stripping leading/trailing underscores
"""
import re

# Remove _modules_, _parameters_, _buffers_ patterns
name = re.sub(r"_modules_", "_", name)
name = re.sub(r"_parameters_", "_", name)
name = re.sub(r"_buffers_", "_", name)
# Collapse multiple underscores
name = re.sub(r"_+", "_", name)
# Strip leading/trailing underscores
name = name.strip("_")
# Remove l_self_ or L_self_ prefix (common guard access pattern)
name = re.sub(r"^[lL]_self_", "", name)
return name
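
As a quick illustration of the normalization above, a minimal sketch that applies the function to the example placeholder name from the docstring; the exact strings dynamo produces can vary between versions, so treat the input as illustrative only:

placeholder_name = "L_self___modules_linear___parameters_weight_"  # example from the docstring
normalized = normalize_placeholder_name(placeholder_name)
# _modules_/_parameters_ are dropped, runs of underscores collapse, and the L_self_ prefix is stripped
assert normalized == "linear_weight"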


def create_graph_removing_unused_inputs_and_adding_unused_parameters(
src_gm: torch.fx.GraphModule,
fake_mode,
) -> tuple[torch.fx.GraphModule, list[torch.Tensor]]:
"""
Create a new GraphModule from src_gm where parameter/buffer placeholders
Member

why not directly mutate src_gm?

Contributor Author

I'd be happy to directly mutate src_gm, but there seemed to be a number of properties of src_gm that would also need to be modified, which I didn't know about beforehand and which don't relate only to the underlying fx.Graph (the CodeGen, IIRC, but I don't remember anymore).

The current implementation seems to behave as I would expect, but I'm afraid it might be missing something that other parts of the system rely on.

are replaced with get_attr nodes.

Uses dynamo_flat_name_to_original_fqn metadata to map placeholder names to FQNs.

Returns (new_gm, inputs) where inputs are the non-param/buffer placeholder values.
"""
from torch._subclasses.fake_tensor import FakeTensor
from torch.export.unflatten import _assign_attr, _AttrKind

# Helper to create fresh fake tensors in the target fake mode
def to_fake(t: torch.Tensor) -> torch.Tensor:
with fake_mode:
fake_t = torch.empty_strided(
t.shape,
t.stride(),
dtype=t.dtype,
device=t.device,
requires_grad=t.requires_grad,
)
if isinstance(t, torch.nn.Parameter):
return torch.nn.Parameter(fake_t, requires_grad=t.requires_grad)
return fake_t

# Helper to convert example_value in node metadata to the target fake mode
def convert_node_meta(meta: dict) -> dict:
new_meta = meta.copy()
example_val = new_meta.get("example_value")
if example_val is not None:
if isinstance(example_val, FakeTensor) and hasattr(example_val, "shape"):
new_meta["example_value"] = to_fake(example_val)
elif isinstance(example_val, (list, tuple)):
# Handle tuple/list of tensors
converted = []
for v in example_val:
if isinstance(v, FakeTensor) and hasattr(v, "shape"):
converted.append(to_fake(v))
else:
converted.append(v)
new_meta["example_value"] = type(example_val)(converted)
return new_meta

# Create new GraphModule and register all parameters/buffers from src_gm
# Convert them to the target fake mode to avoid fake mode mixing
gm = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())

# Only copy specific metadata keys we need, avoiding tracing_context and other
# internal state that might hold references to dynamo's fake mode
for key in ["dynamo_flat_name_to_original_fqn", "module_call_specs"]:
if key in src_gm.meta:
gm.meta[key] = src_gm.meta[key]
gm.meta["fake_mode"] = fake_mode

for name, mod in src_gm.named_children():
if isinstance(mod, torch.fx.GraphModule):
gm.add_submodule(name, mod)
for fqn, param in src_gm.named_parameters():
_assign_attr(to_fake(param), gm, fqn, _AttrKind.PARAMETER)
for fqn, buf in src_gm.named_buffers():
_assign_attr(to_fake(buf), gm, fqn, _AttrKind.BUFFER)

# Build lookup from normalized name to FQN using dynamo_flat_name_to_original_fqn
# The keys in dynamo_flat_name_to_original_fqn are created by module_key_name(source.name)
# We normalize these keys to match our normalized placeholder names
flat_name_to_fqn = src_gm.meta.get("dynamo_flat_name_to_original_fqn", {})
normalized_to_fqn = {}
for flat_name, fqn in flat_name_to_fqn.items():
# module_key_name output is already mostly normalized, but we apply the same
# normalization to ensure consistency
normalized = normalize_placeholder_name(flat_name)
normalized_to_fqn[normalized] = fqn

param_fqns = {fqn for fqn, _ in src_gm.named_parameters()}
buffer_fqns = {fqn for fqn, _ in src_gm.named_buffers()}

graph = gm.graph
val_map = {}
inputs = []
used_params = set()
used_buffers = set()

for node in src_gm.graph.nodes:
if node.op == "placeholder":
example_val = node.meta.get("example_value")

# Try to find FQN by normalizing placeholder name and looking up
normalized_name = normalize_placeholder_name(node.name)
fqn = normalized_to_fqn.get(normalized_name) # type: ignore[assignment]

is_param = fqn is not None and fqn in param_fqns
is_buffer = fqn is not None and fqn in buffer_fqns

if is_param:
# Parameter placeholder -> get_attr
get_attr_node = graph.get_attr(fqn)
get_attr_node.meta = convert_node_meta(node.meta)
val_map[node] = get_attr_node
used_params.add(fqn)
elif is_buffer:
# Buffer placeholder -> get_attr
get_attr_node = graph.get_attr(fqn)
get_attr_node.meta = convert_node_meta(node.meta)
val_map[node] = get_attr_node
used_buffers.add(fqn)
else:
# Regular input placeholder - copy as-is
new_node = graph.node_copy(node, lambda n: val_map[n])
new_node.meta = convert_node_meta(node.meta)
if example_val is not None and hasattr(example_val, "shape"):
inputs.append(to_fake(example_val))
val_map[node] = new_node
else:
# Copy all other nodes and convert their metadata
new_node = graph.node_copy(node, lambda n: val_map[n])
new_node.meta = convert_node_meta(node.meta)
val_map[node] = new_node

# Add get_attr for unused parameters (not in the original graph)
# Insert before the first non-placeholder/non-get_attr node
insert_point = None
for node in graph.nodes:
if node.op not in ("placeholder", "get_attr"):
insert_point = node
break

for fqn in param_fqns - used_params:
param = src_gm.get_parameter(fqn)
if insert_point is not None:
with graph.inserting_before(insert_point):
get_attr_node = graph.get_attr(fqn)
with fake_mode:
get_attr_node.meta["example_value"] = torch.empty_strided(
param.shape,
param.stride(),
dtype=param.dtype,
device=param.device,
requires_grad=param.requires_grad,
)

for fqn in buffer_fqns - used_buffers:
buf = src_gm.get_buffer(fqn)
if insert_point is not None:
with graph.inserting_before(insert_point):
get_attr_node = graph.get_attr(fqn)
with fake_mode:
get_attr_node.meta["example_value"] = torch.empty_strided(
buf.shape,
buf.stride(),
dtype=buf.dtype,
device=buf.device,
requires_grad=buf.requires_grad,
)

graph.lint()
gm.recompile()

return gm, inputs
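
For context, a hedged sketch of how the returned pair is meant to be consumed; `captured_gm` and `fake_mode` below are placeholders for the dynamo-captured GraphModule and the caller's FakeTensorMode, mirroring the call in build_model_graph further down in this diff:

# Illustrative only: 'captured_gm' is assumed to carry "dynamo_flat_name_to_original_fqn" in its meta.
new_gm, inputs = create_graph_removing_unused_inputs_and_adding_unused_parameters(
    captured_gm, fake_mode
)
# new_gm now owns fake copies of every parameter/buffer (reached via get_attr nodes),
# while inputs holds fake tensors only for the remaining non-parameter placeholders;
# this pair is what gets handed to aot_export_joint_with_descriptors below.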


def _assign_attr(
attr: Any,
target_module: torch.nn.Module,
@@ -376,13 +573,26 @@ def build_model_graph(self):
with set_dtype_cast(
True
), enable_local_map_wrapping(), torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
torch_ir_with_fqn = _export(self.model, model_wrapper, formatted_inputs)
# TODO Can't use fake mode here because it clashes with the user level
# fake mode. Ideally dynamo should reuse the user level fake mode.
from torch._dynamo.functional_export import dynamo_graph_capture_for_export

torch_ir_with_fqn = dynamo_graph_capture_for_export(self.model)(
*formatted_inputs
)
self.flatten_fn = torch_ir_with_fqn._dynamo_bytecode_flatten
self.unflatten_fn = torch_ir_with_fqn._dynamo_bytecode_unflatten
(
torch_ir_with_fqn2,
inputs2,
) = create_graph_removing_unused_inputs_and_adding_unused_parameters(
torch_ir_with_fqn, self.fake_mode
)
# Clear references to the original dynamo graph to help garbage collection
# and reduce state leakage between uses
del torch_ir_with_fqn
self.joint_with_descriptors = aot_export_joint_with_descriptors(
self.stack,
torch_ir_with_fqn,
formatted_inputs,
torch_ir_with_fqn2,
inputs2,
decompositions=decomp_table,
)
gm = self.joint_with_descriptors.graph_module
Expand Down Expand Up @@ -607,6 +817,8 @@ def apply_placement(self, sharding_placement=None):
bw_compiler=self.compiler_fn,
)

unflatten_fn = self.unflatten_fn

# TODO: this probably belongs in the AOTAutograd API
# TODO: pytree handling
class AutoParallelModule(torch.nn.Module):
@@ -624,10 +836,13 @@ def forward(self, *args):
dict(self.named_buffers(remove_duplicate=False)).items(),
)
]
boxed_args = [*params, *args]
# new_args = flatten_fn(*args)
filtered_args = [x for x in args if isinstance(x, torch.Tensor)]
boxed_args = [*params, *filtered_args]
del params
# NB: don't use self.parallel_model_fn, to work around a Dynamo bug
out = parallel_model_fn(boxed_args)
out = unflatten_fn(out, args)
return out

self.parallel_model = AutoParallelModule()
8 changes: 6 additions & 2 deletions autoparallel/shardings/ordered_sharding.py
@@ -94,7 +94,7 @@ def get_redistributed_input_placements(
x for x in tree_flatten(node.args)[0] if isinstance(x, torch.fx.Node)
]
num_input_nodes = len(all_input_nodes)
curr_specs: list[Union[DTensorSpec, tuple[Optional[DTensorSpec], ...]]] = [
curr_specs: list[Union[DTensorSpec, tuple[Optional[DTensorSpec], ...], None]] = [
sharding_placement[n].output_specs for n in all_input_nodes
] # FIXME ?
if node.target == operator.getitem:
@@ -149,7 +149,11 @@ def compute_optimal_placement_order_for_parameters(module, sharding_placement):
param_and_grad_users = {}
param_grad_chain = {}
for param, grad in param_and_grad_nodes:
last_p = list(param.users)[0]
param_users = list(param.users)
if not param_users:
# if unused parameter, don't bother with it
continue
last_p = param_users[0]
p_chain = [param]
# get all linear chain of users of the parameter
while len(last_p.all_input_nodes) == 1:
47 changes: 47 additions & 0 deletions tests/test_api.py
@@ -117,6 +117,53 @@ def input_fn():
)


def test_non_tensor_input(device_mesh_1d):
dim = 128

class Model(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
self.linear = nn.Linear(dim, dim)
self.unused = nn.Parameter(torch.rand(1))

def forward(self, x, input_dim: int, has_compute: bool):
if has_compute:
return self.linear(x) + input_dim, x.shape
else:
return x

def init_weights(self):
dim = self.dim
self.linear.weight = torch.nn.Parameter(torch.ones(dim, dim) * 9.0)
with torch.no_grad():
self.linear.bias.fill_(98.6)

def input_fn(divisor=1):
b = 512 // divisor
inputs = torch.rand(b, dim, device="cuda")
input_dim = 1
has_compute = True
return (inputs, input_dim, has_compute)

with torch.device("meta"):
model = Model(dim)
with AutoParallel(model, input_fn, device_mesh_1d, compile=True) as autop:
x_sharding = (Shard(0),)
autop.add_input_constraints([x_sharding, None])
sharding_placement = autop.optimize_placement()

parallel_mod = autop.apply_placement(sharding_placement)
parallel_mod.to_empty(device="cuda")
parallel_mod.init_weights()
placeholders = autop.gm.graph.find_nodes(op="placeholder")
# 2 used parameters, 1 unused parameter, 1 input and 1 tangent
assert len(placeholders) == 5
inputs = input_fn(device_mesh_1d.shape[0])
out, shape = parallel_mod(*inputs)
assert input_fn()[0].shape == shape


def test_fx_graph_annotate(device_mesh_1d):
dim = 128

7 changes: 7 additions & 0 deletions tests/test_optimize_placement.py
@@ -16,6 +16,13 @@

from autoparallel.api import AutoParallel

# @pytest.fixture(autouse=True)
Contributor Author

@xmfan it looks like I'm not properly resetting some global state in dynamo with my changes. This means that running a single test in isolation works fine, but running the full suite fails.

The failure disappears if I reset the dynamo state, but to me this seems like a fundamental issue.

Claude mentioned

Looking at this more carefully, the issue is that dynamo_graph_capture_for_export creates global caches that persist. The flatten_fn and unflatten_fn you're storing also likely hold references to internal dynamo state.

  There are a few things we can try:

  1. Explicitly delete the original graph after we're done with it
  2. Call dynamo reset in __exit__ to clean up after each AutoParallel usage

  Let me try the first approach - deleting the original graph after extracting what we need:

● Update(autoparallel/api.py)
  ⎿  Added 3 lines                                                                                                                          
      586               ) = create_graph_removing_unused_inputs_and_adding_unused_parameters(                                               
      587                   torch_ir_with_fqn, self.fake_mode                                                                               
      588               )                                                                                                                   
      589 +             # Clear references to the original dynamo graph to help garbage collection                                          
      590 +             # and reduce state leakage between uses                                                                             
      591 +             del torch_ir_with_fqn                                                                                               
      592               self.joint_with_descriptors = aot_export_joint_with_descriptors(                                                    
      593                   self.stack,                                                                                                     
      594                   torch_ir_with_fqn2,                                                                                             

● Please try running the tests again to see if deleting the original graph helps.

  However, I should be honest about the core issue: dynamo maintains global caches for compiled functions and their associated state. When you call dynamo_graph_capture_for_export, it:

  1. Compiles the function and stores it in a global cache (keyed by function identity)
  2. Creates internal state (fake modes, shape environments) that persists in these caches

  Even if we delete all local references, these global caches persist. This is intentional for performance - dynamo caches compiled functions to avoid recompiling them on subsequent calls.

  The torch._dynamo.reset() call clears these global caches, which is why it fixes the issue. This is the standard and recommended way to ensure test isolation when using dynamo - it's not a workaround for a bug in your code.

  If the tests still fail after the del, you have a few options:

  1. Keep the reset_dynamo fixture - This is the correct solution for test isolation with dynamo
  2. Add cleanup in __exit__ - Call torch._dynamo.reset() when exiting the AutoParallel context, but this could affect other dynamo users in the same process
  3. Accept that dynamo caches state - In production, this caching is desirable for performance

Member

This is the Python test implementation though; the tests reuse the same process for faster run times. In core, we mark problematic tests with the reset, either manually or with a decorator (see the sketch after the commented-out fixture below).

Contributor Author

So you mean that this isn't something I should worry about, and just reset dynamo?

# def reset_dynamo():
# """Reset dynamo state before each test to ensure test isolation."""
# torch._dynamo.reset()
# yield
# torch._dynamo.reset()
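
A minimal sketch of the decorator-style alternative mentioned above (the helper name is illustrative, not an existing torch API); it wraps a single test with torch._dynamo.reset() instead of resetting around every test via an autouse fixture:

import functools

import torch._dynamo


def with_dynamo_reset(test_fn):
    # Illustrative helper: clear dynamo's global caches before and after one test.
    @functools.wraps(test_fn)
    def wrapper(*args, **kwargs):
        torch._dynamo.reset()
        try:
            return test_fn(*args, **kwargs)
        finally:
            torch._dynamo.reset()
    return wrapper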


@pytest.fixture(scope="module", autouse=True)
def init_pg():