make cudagraphs a transform #977
Conversation
```diff
- fusion_bsym: BoundSymbol = self.fuse(region, fusion_counter, num_static_inputs)
+ fusion_bsym: BoundSymbol = self.fuse(region, fusion_counter)
```
Are we missing num_static_inputs here for a reason? If we get several non-isomorphic graphs, we should probably think about how to handle this parameter better, maybe through a callback. But that's not relevant now...
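A purely illustrative sketch of that callback idea (none of these names exist in Thunder; the class, `num_static_inputs_fn`, and `region.inputs` are hypothetical):

```python
from typing import Callable, Sequence

def default_num_static_inputs(region_inputs: Sequence) -> int:
    # Conservative default: no leading inputs are treated as static.
    return 0

class CallbackCUDAGraphTransform:
    """Hypothetical variant where num_static_inputs is decided per fusion region."""

    def __init__(self, num_static_inputs_fn: Callable[[Sequence], int] = default_num_static_inputs):
        self.num_static_inputs_fn = num_static_inputs_fn

    def fuse_region(self, region, fusion_counter: int):
        # Ask the callback per region instead of relying on one trace-global value.
        num_static_inputs = self.num_static_inputs_fn(region.inputs)
        ...  # build the CUDA graph fusion with that many leading static inputs
```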
I do think this is one of the places where the default transform is not ideal, but it is not something that could be used before either: the parameter was there with a default, but there was no way of providing it.
(And I am not sure that it is correct to have a nontrivial trace-global parameter for it either.)
It is used in the backward pass. We are not losing it there, are we?
Ouch, right.
But yes, the design of this parameter is so-so, as it did not anticipate graph breaks...
But which args were covered by it?
So, this is the removed code:

```python
if cd.use_cudagraphs:
    from thunder.executors.cudagraphex import cudagraphex

    computation_trc = cudagraphex.fusion_pass(computation_trc)
    computation_traces.append(computation_trc)

    if backward_trc is not None:
        backward_trc = cudagraphex.fusion_pass(backward_trc, num_static_inputs=len(backward_trc.args[0][0]))
        backward_traces.append(backward_trc)
```
So it seems that these were the saved-for-backward values of (???). Is the assumption here that they were either static in the forward (parameters) or copied to the input area of the forward CUDA graph?
But with the old code:

```python
import torch, thunder

with torch.device("cuda"):
    m = torch.nn.Linear(2, 3)
    inp = torch.randn(1, 2, requires_grad=True)
    jm = thunder.jit(m, use_cudagraphs=True)
    res = jm(inp)
    grads = torch.autograd.grad(res.sum(), (inp, *m.parameters()))
```
the forward has no CUDA graph, so in that case treating the input as static is not really correct (admittedly, a corner case).
I guess the happy part is that we will get the parameters as static anyway, and we will have to look into the buffers for our own good...
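A minimal sketch of that idea (the helper is hypothetical, not Thunder code): parameters and buffers keep the same storage across iterations, so they are natural static CUDA graph inputs, while anything else has to be copied into the graph's input buffers before each replay.

```python
import torch

def likely_static_inputs(module: torch.nn.Module) -> set[torch.Tensor]:
    # Parameters and buffers are reused in place between iterations,
    # so their addresses can safely be baked into the captured graph.
    return set(module.parameters()) | set(module.buffers())
```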
Would there be a plan to add a test which demonstrates the composability with e.g. FSDP?
That would be awesome, I really need to get the env for running distributed tests back. :(
So, as we talk about plans:
To my mind, this greatly improves our starting position (similar to, but smaller than, Nik's work to make this a (then hardcoded) transform instead of a callable wrapper).
This is cool; how do you see CUDA graphs as a transform? If we conceptually separate the transforms like this:
Where would the CUDA graphs transform go? If it's a post-execution (or destructive) transform, should we think about how executors might label their operations as being CUDA graph compatible (or not)? If it's a pre-execution transform, does it preclude using certain executors in the execution transform?
So it is after the transform for execution ("post optimization"), but our experimentation (which informs some of the CUDAGraph work) led us to go before the "del" pass.
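A toy sketch of that ordering (all names are placeholders, not Thunder's actual pass names):

```python
def transform_for_execution_placeholder(trace):
    # placeholder: executor selection / fusion passes ("post optimization" point)
    return trace

def cudagraph_transform_placeholder(trace):
    # placeholder: wrap eligible regions into CUDA graph fusions
    return trace

def del_pass_placeholder(trace):
    # placeholder: insert "del" statements for last-used variables
    return trace

def pipeline(trace):
    # The CUDA graph transform runs after transform-for-execution
    # but before the "del" pass.
    trace = transform_for_execution_placeholder(trace)
    trace = cudagraph_transform_placeholder(trace)
    trace = del_pass_placeholder(trace)
    return trace
```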
👍
I have merged this after coordination with @IvanYashchuk; we will address further review comments in a follow-up.
This introduces `thunder.transform.cudagraph.CUDAGraphTransform` to replace `use_cudagraphs=True`. For more detailed control, users can subclass the transform and e.g. override `can_fuse`.
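As a rough usage sketch (the exact `can_fuse` signature, the `bsym.sym.name` access, and the `transforms=` argument to `thunder.jit` are assumptions here; the module path is taken from the description above):

```python
import thunder
from thunder.transform.cudagraph import CUDAGraphTransform  # path as given above

class NoInplaceCUDAGraphTransform(CUDAGraphTransform):
    """Sketch: additionally refuse to capture bound symbols that look in-place."""

    def can_fuse(self, bsym):
        # Hypothetical extra filter on top of the default behaviour.
        if "inplace" in bsym.sym.name:
            return False
        return super().can_fuse(bsym)

# jm = thunder.jit(model, transforms=[NoInplaceCUDAGraphTransform()])
```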