Rename rematerialization of saved for backward symbols #1367

Closed
wants to merge 4 commits into from

Conversation

riccardofelluga
Collaborator

@riccardofelluga riccardofelluga commented Oct 30, 2024

This is part of #1232. This PR renames the outputs of recomputed symbols so that they do not overlap with names used in the forward trace. Fusion rematerialization requires the names used in producer and consumer fusions to be unique.
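A minimal sketch of the renaming idea, using stand-in types rather than Thunder's own. `Sym`, `rename_recomputed`, and the `remat_` prefix are hypothetical names chosen for illustration; the actual change operates on trace symbols and proxies.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Sym:
    """Hypothetical stand-in for a trace symbol: an op, named inputs, one named output."""
    op: str
    args: tuple
    out: str


def rename_recomputed(recomputed, used_names, prefix="remat_"):
    """Return copies of `recomputed` whose outputs avoid every name in `used_names`.

    Later symbols that consume a renamed output are rewired to the new name.
    """
    used = set(used_names)
    swap = {}  # old output name -> new output name
    renamed = []
    for sym in recomputed:
        # Rewire inputs that refer to outputs renamed earlier in the list.
        new_args = tuple(swap.get(a, a) for a in sym.args)
        # Pick a fresh output name, adding a counter if the prefixed name is taken.
        new_out = prefix + sym.out
        counter = 0
        while new_out in used:
            counter += 1
            new_out = f"{prefix}{sym.out}_{counter}"
        used.add(new_out)
        swap[sym.out] = new_out
        renamed.append(replace(sym, args=new_args, out=new_out))
    return renamed
```

Inputs that are not produced by a recomputed symbol (for example the forward arguments) keep their names, so only the recomputed intermediates change.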

Collaborator

@IvanYashchuk IvanYashchuk left a comment

Other reviewers, please don't merge this PR without my review.

@IvanYashchuk IvanYashchuk removed their assignment Nov 4, 2024
@@ -3148,6 +3148,9 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr

producers = find_producer_symbols(fwd_trace, tuple(unvariableify(i) for i in rematerializable), fwd_trace.args)

trace_tok = set_tracectx(bwd_trace)
Collaborator

Please set and reset traces only with "try: finally:" blocks. If there's any error between the calls, the trace will not be reset.
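A sketch of the requested pattern. The three context functions below are stand-ins built on `contextvars`, not Thunder's implementation, and `run_with_trace` is a hypothetical helper.

```python
from contextvars import ContextVar

# Stand-in for the active-trace context variable (an assumption for this sketch).
_tracectx = ContextVar("tracectx", default=None)


def set_tracectx(trace):
    """Make `trace` the active trace and return a token for restoring the old one."""
    return _tracectx.set(trace)


def reset_tracectx(token):
    """Restore whatever trace was active before the matching set_tracectx call."""
    _tracectx.reset(token)


def get_tracectx():
    return _tracectx.get()


def run_with_trace(trace, work):
    """Run `work()` with `trace` active; the previous trace is restored even on error."""
    token = set_tracectx(trace)
    try:
        return work()
    finally:
        reset_tracectx(token)
```

If `work()` raises, the `finally` clause still runs, so a failure in the middle does not leave `trace` set as the active trace.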

Collaborator

Why do you set the input bwd_trace as the active trace? There are no Thunder operation calls between set and reset, and the input trace shouldn't be modified.

Collaborator

Would we not use `with tracectx(bwd_trace)`?

Collaborator Author

For this I took inspiration from the code in the torch_autograd executor. In particular, these lines explain why the trace context needs to be set:

# [note: why setting trace ctx?]
# [`TensorProxy.replace_name`](https://github.com/Lightning-AI/lightning-thunder/blob/561b699/thunder/core/proxies.py#L1221-L1223) calls
# [`tensorproxy`](https://github.com/Lightning-AI/lightning-thunder/blob/561b699/thunder/core/proxies.py#L1506-L1520)
# which then calls `TensorProxy.__init__`. `TensorProxy.__init__` of course calls
# [`Proxy.__init__`](https://github.com/Lightning-AI/lightning-thunder/blob/561b699/thunder/core/proxies.py#L81-L86).
# `Proxy`'s dunder init calls [`make_proxy_name`](https://github.com/Lightning-AI/lightning-thunder/blob/561b699/thunder/core/proxies.py#L81-L86)
# which depends on a tracectx.
trace_tok = set_tracectx(bwd_trace)

@IvanYashchuk Would an acceptable workaround be to create a new empty trace and use it as ctx?

Collaborator

Yes, or allow creating Proxies with any name without an active tracectx. Maybe all that is needed is to return True if trc is None in this function:

def register_proxy_name(name: None | str = None):
    trc = get_tracectx()
    if name is not None and not trc.has_name(name):
        trc.add_name(name)
        return True
    return False

@IvanYashchuk IvanYashchuk self-assigned this Nov 4, 2024
@riccardofelluga
Collaborator Author

Note that this does not fix #1232, but it helps to debug it by producing a clearer backward trace when remat of saved for backward is enabled.

@t-vi
Collaborator

t-vi commented Jan 10, 2025

@riccardofelluga I think this is superseded by doing the renaming in the rematerialization itself.
Please reopen if I am missing something (entirely possible).

@t-vi t-vi closed this Jan 10, 2025