Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable additional transforms for the PyTorch Autograd path #74

Merged

Conversation

IvanYashchuk
Copy link
Collaborator

Applying transforms on the trace after "forward backward split + transform_for_execution" is not a well-tested path and breaks the careful sorting of operations introduced inside the split_forward_backward function. In the future, we need to revisit the order of transformations and move execution transforms out of the split_forward_backward function. Today this PR is needed to restore distributed communication overlap with computation and unblock fixing TransformerEngine+DataParallel.

Copy link
Collaborator

@t-vi t-vi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @IvanYashchuk

@t-vi
Copy link
Collaborator

t-vi commented Mar 25, 2024

So I generally agree with the patch, but the CI has opinions.

@IvanYashchuk
Copy link
Collaborator Author

Notebooks fail with:

     18 cache_rec, i_, _ = thunder.compile_data(thunder_model).get_computation_and_inputs(x)
---> 19 computation_trace = cache_rec.computation_traces[0]

IndexError: list index out of range

@IvanYashchuk
Copy link
Collaborator Author

Next error in the notebooks:

nbclient.exceptions.CellExecutionError: An error occurred while executing the following cell:
------------------
### DON'T TRY THIS AT HOME
computation_trace.bound_symbols[2].sym = cache_rec.computation_traces[0].bound_symbols[2].subsymbols[0].sym
if cache_rec.computation_traces[0].bound_symbols[3].subsymbols:
    computation_trace.bound_symbols[3] = cache_rec.computation_traces[0].bound_symbols[3].subsymbols[0]
computation_trace.bound_symbols[4].sym = cache_rec.computation_traces[0].bound_symbols[4].subsymbols[0].sym

wrap_as_highlighted_code(computation_trace)

Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@IvanYashchuk IvanYashchuk enabled auto-merge (squash) March 26, 2024 10:21
@IvanYashchuk IvanYashchuk merged commit bdf5c3f into Lightning-AI:main Mar 26, 2024
37 checks passed
@github-actions github-actions bot deleted the disable-additional-transforms branch July 17, 2024 00:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants