Add support for arbitrary inputs and outputs in AutoParallel #311
base: main
Conversation
    fake_mode,
) -> tuple[torch.fx.GraphModule, list[torch.Tensor]]:
    """
    Create a new GraphModule from src_gm where parameter/buffer placeholders
why not directly mutate src_gm?
I'd be happy to mutate src_gm directly, but there seemed to be a number of properties of src_gm that would also need to be modified, which I didn't know about beforehand and which don't relate only to the underlying fx.Graph (the CodeGen, IIRC, but I don't remember anymore).
The current implementation behaves as I would expect, but I'm afraid it might be missing something that other parts of the system rely on.
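For context, a minimal sketch of the copy-instead-of-mutate pattern being discussed (illustrative only; the helper name and signature here are not the PR's actual API): copy the nodes into a fresh fx.Graph and wrap it in a new GraphModule that reuses src_gm as the root, leaving src_gm and its generated code untouched.

import torch
import torch.fx


def clone_into_new_graph_module(src_gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Copy src_gm.graph into a brand-new Graph instead of editing it in place,
    # so src_gm (including its CodeGen / generated code) stays intact.
    new_graph = torch.fx.Graph()
    val_map: dict = {}
    output_val = new_graph.graph_copy(src_gm.graph, val_map)
    new_graph.output(output_val)

    # The new GraphModule reuses src_gm as the owner of parameters/buffers.
    return torch.fx.GraphModule(src_gm, new_graph)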
from autoparallel.api import AutoParallel

# @pytest.fixture(autouse=True)
@xmfan it looks like I'm not properly resetting some global state in dynamo with my changes. This means that running a single test in isolation works fine, but running the full suite fails.
The failure disappears if I reset the dynamo state, but to me this seems like a fundamental issue.
Claude mentioned:
Looking at this more carefully, the issue is that dynamo_graph_capture_for_export creates global caches that persist. The flatten_fn and unflatten_fn you're storing also likely hold references to internal dynamo state.
There are a few things we can try:
1. Explicitly delete the original graph after we're done with it
2. Call dynamo reset in __exit__ to clean up after each AutoParallel usage
Let me try the first approach - deleting the original graph after extracting what we need:
● Update(autoparallel/api.py)
⎿ Added 3 lines

  ) = create_graph_removing_unused_inputs_and_adding_unused_parameters(
      torch_ir_with_fqn, self.fake_mode
  )
+ # Clear references to the original dynamo graph to help garbage collection
+ # and reduce state leakage between uses
+ del torch_ir_with_fqn
  self.joint_with_descriptors = aot_export_joint_with_descriptors(
      self.stack,
      torch_ir_with_fqn2,
● Please try running the tests again to see if deleting the original graph helps.
However, I should be honest about the core issue: dynamo maintains global caches for compiled functions and their associated state. When you call dynamo_graph_capture_for_export, it:
1. Compiles the function and stores it in a global cache (keyed by function identity)
2. Creates internal state (fake modes, shape environments) that persists in these caches
Even if we delete all local references, these global caches persist. This is intentional for performance - dynamo caches compiled functions to avoid recompiling them on subsequent calls.
The torch._dynamo.reset() call clears these global caches, which is why it fixes the issue. This is the standard and recommended way to ensure test isolation when using dynamo - it's not a workaround for a bug in your code.
If the tests still fail after the del, you have a few options:
1. Keep the reset_dynamo fixture - This is the correct solution for test isolation with dynamo
2. Add cleanup in __exit__ - Call torch._dynamo.reset() when exiting the AutoParallel context, but this could affect other dynamo users in the same process
3. Accept that dynamo caches state - In production, this caching is desirable for performance
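For illustration, a minimal sketch of option 2 above, assuming AutoParallel is used as a context manager (the real class in autoparallel/api.py may organize this differently):

import torch


class AutoParallelContextSketch:
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Clear dynamo's global caches on exit. Note this is process-wide
        # and may affect other dynamo users, as the caveat above points out.
        torch._dynamo.reset()
        return False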
This is the Python test implementation, though; the tests reuse the same process for faster run times. In core, we mark problematic tests with the reset, either manually or with a decorator.
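For reference, a minimal sketch of such a reset fixture (the commented-out @pytest.fixture(autouse=True) in the diff above presumably looks similar; names here are illustrative):

import pytest
import torch


@pytest.fixture(autouse=True)
def reset_dynamo():
    # Reset dynamo's global caches before and after each test so compiled
    # state does not leak between tests that share the same process.
    torch._dynamo.reset()
    yield
    torch._dynamo.reset()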
Non-tensor inputs get baked into the graph. We need to add an assert to ensure they haven't changed on the user side.
The overall idea is to rely on the output from dynamo_graph_capture_for_export, which has a _dynamo_bytecode_flatten and a _dynamo_bytecode_unflatten and passes to its underlying fx.Graph only the required inputs, with ints / bools already baked into the graph. I additionally add the unused parameters of the model as inputs to the graph, so that they can be sharded as well (just for consistency).
Subsumes #264
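As a rough illustration of the assert mentioned above (a hypothetical helper, not the PR's actual API): record the non-tensor inputs seen at capture time and check them on every subsequent call, since their values are baked into the captured graph.

import torch


def make_non_tensor_input_guard(captured_args):
    # Remember the positions and values of the non-tensor inputs that were
    # baked into the graph at capture time.
    baked = [
        (i, a) for i, a in enumerate(captured_args)
        if not isinstance(a, torch.Tensor)
    ]

    def check(args):
        for i, expected in baked:
            assert args[i] == expected, (
                f"non-tensor input #{i} changed from {expected!r} to {args[i]!r}; "
                "its value was baked into the captured graph"
            )

    return check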