Conversation
```diff
- fusion_bsym: BoundSymbol = self.fuse(region, fusion_counter, num_static_inputs)
+ fusion_bsym: BoundSymbol = self.fuse(region, fusion_counter)
```
Are we missing num_static_inputs here for a reason? If we get several non-isomorphic graphs, we should probably think about how to handle this parameter better, maybe through a callback. But that's not relevant now...
I do think that this is one of the things where the default transform is not ideal, but it is not something that was usable before: the parameter existed with a default, but there was no way of providing it.
(And I am not sure that it is correct to have a nontrivial trace-global parameter for it either.)
It is used in the backward pass. We are not losing it there, are we?
But yes, the design of this parameter is so-so, as graph breaks were not anticipated when it was added...
But which args were covered by it?
So, this is the removed code:
```python
if cd.use_cudagraphs:
    from thunder.executors.cudagraphex import cudagraphex

    computation_trc = cudagraphex.fusion_pass(computation_trc)
    computation_traces.append(computation_trc)

    if backward_trc is not None:
        backward_trc = cudagraphex.fusion_pass(backward_trc, num_static_inputs=len(backward_trc.args[0][0]))
        backward_traces.append(backward_trc)
```
So it seems that these were the saved-for-backward tensors of (???). Is the assumption here that they were either static in the forward (parameters) or copied to the input area of the forward CUDA graph?
But with the old code:
```python
import torch, thunder

with torch.device("cuda"):
    m = torch.nn.Linear(2, 3)
    inp = torch.randn(1, 2, requires_grad=True)
    jm = thunder.jit(m, use_cudagraphs=True)
    res = jm(inp)
    grads = torch.autograd.grad(res.sum(), (inp, *m.parameters()))
```
the forward has no CUDA graph, so in that case, treating the input as static is not really correct (admittedly, a corner case).
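To illustrate the static-vs-copied distinction being discussed, here is a framework-free sketch (the class and names below are invented for illustration and are not Thunder's API): a captured graph replays against fixed buffers, so non-static inputs must be copied into a fixed input area before each replay, while static inputs are assumed to stay at the same address.

```python
# Illustrative sketch of CUDA-graph-style replay semantics. All names here
# are invented for this example; they do not correspond to Thunder's API.
# A "captured" function reads from fixed buffers: static inputs are baked
# in at capture time, while dynamic inputs must be copied into the fixed
# input area before each replay.

class CapturedGraph:
    def __init__(self, fn, static_inputs, num_dynamic):
        self.fn = fn
        self.static = list(static_inputs)        # assumed to never move or change
        self.input_area = [None] * num_dynamic   # fixed buffers for dynamic inputs

    def replay(self, *dynamic_inputs):
        # copy step: dynamic inputs go into the fixed input area first
        for i, x in enumerate(dynamic_inputs):
            self.input_area[i] = x
        return self.fn(*self.static, *self.input_area)

# "Capture" w = 10 as static; x stays dynamic and is copied on every replay.
g = CapturedGraph(lambda w, x: w * x, static_inputs=[10], num_dynamic=1)
```

Mislabeling a dynamic input as static in this model would skip the copy step, which is exactly the kind of incorrectness the corner case above points at.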
I guess the happy part is that we will get the parameters as static anyway, and we will have to look into the buffers for our own good...
crcrpar left a comment
Would there be a plan to add a test which demonstrates composability with e.g. FSDP?
That would be awesome; I really need to get the env for running distributed tests back. :(
So, as we talk about plans: to my mind, this greatly improves our starting position (similar to, but smaller than, Nik's work to make this a (then hardcoded) transform instead of a callable wrapper).
This is cool! How do you see CUDA graphs as a transform? If we conceptually separate the transforms like this: where would the CUDA graphs transform go? If it's a post-execution (or destructive) transform, should we think about how executors might label their operations as CUDA-graphs-compatible (or not)? If it's a pre-execution transform, does it preclude using certain executors in the execution transform?
So it is after the transform for execution ("post optimization"), but our experimentation (which informs some of the CUDAGraph work) led us to go before the "del" pass.
I have merged this after coordination with @IvanYashchuk; we will address further review comments in a follow-up.
This introduces `thunder.transform.cudagraph.CUDAGraphTransform` to replace `use_cudagraphs=True`. For more detailed control, users can subclass the transform and e.g. override `can_fuse`.
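A minimal sketch of the subclassing pattern the description mentions. The base class below is a stand-in; the real `CUDAGraphTransform`, the exact `can_fuse` signature, and the attribute layout of a bound symbol should all be taken from the Thunder source, so treat these details as assumptions for illustration.

```python
# Illustrative stand-in for Thunder's CUDAGraphTransform. The real class
# lives in thunder.transform.cudagraph (per this PR); the can_fuse
# signature and the bsym attribute access here are assumptions.
class CUDAGraphTransform:
    def can_fuse(self, bsym) -> bool:
        # assumed default: every bound symbol is eligible for the CUDA graph
        return True


class SelectiveCUDAGraphTransform(CUDAGraphTransform):
    """Only let an allowlist of symbol names into the CUDA graph region."""

    def __init__(self, allowed_names):
        self.allowed_names = set(allowed_names)

    def can_fuse(self, bsym) -> bool:
        # bsym is assumed to expose its symbol's name via `bsym.sym.name`
        name = getattr(getattr(bsym, "sym", None), "name", None)
        return name in self.allowed_names


# Hypothetical usage (not runnable here without Thunder and a CUDA device):
#   jm = thunder.jit(m, transforms=[SelectiveCUDAGraphTransform({"linear", "relu"})])
```

The design choice a subclass like this expresses: instead of a trace-global knob such as `num_static_inputs`, fusion policy becomes a per-symbol decision, which composes more naturally when a trace contains several non-isomorphic graph regions.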