
make cudagraphs a transform #977

Merged: 5 commits merged into main from tom/cudagraphs_transform on Aug 20, 2024
Conversation

@t-vi (Collaborator) commented Aug 16, 2024

This introduces thunder.transform.cudagraph.CUDAGraphTransform to replace use_cudagraphs=True.

For more detailed control, users can subclass the transform and e.g. override can_fuse.
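
A minimal usage sketch (mine, not from the PR): it assumes thunder.jit takes a transforms list and that the import path matches the description above; the module may actually live under thunder.transforms.cudagraph.

import torch, thunder
# import path follows the PR description; adjust if the module is thunder.transforms.cudagraph
from thunder.transform.cudagraph import CUDAGraphTransform

with torch.device("cuda"):
    m = torch.nn.Linear(2, 3)
    inp = torch.randn(1, 2)

# replaces the old thunder.jit(m, use_cudagraphs=True)
jm = thunder.jit(m, transforms=[CUDAGraphTransform()])
res = jm(inp)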

@t-vi requested review from mruberry and lantiga as code owners, August 16, 2024 11:52
Review comment on lines 233 to 237:

-    fusion_bsym: BoundSymbol = self.fuse(region, fusion_counter, num_static_inputs)
+    fusion_bsym: BoundSymbol = self.fuse(region, fusion_counter)
@nikitaved (Contributor) commented Aug 16, 2024

Are we missing num_static_inputs here for a reason? If we get several non-isomorphic graphs, we should probably think about how to handle this parameter better, maybe through a callback. But that's not relevant right now...

@t-vi (Collaborator, Author) commented Aug 16, 2024

I do think that this is one of the things where the default transform is not ideal, but it is not something that was used before: the parameter was there with a default, but there was no way of providing it.
(And I am not sure that it is correct to have a nontrivial trace-global parameter for it either.)

@nikitaved (Contributor) commented Aug 16, 2024

It is used in the backward pass. We are not losing it there, are we?

@t-vi (Collaborator, Author):

Ouch, right.

@nikitaved (Contributor) commented Aug 16, 2024

But yes, the design of this parameter is so-so, since it was not expected that there would be graph breaks...

@t-vi (Collaborator, Author):

But which args were covered by it?

@nikitaved (Contributor):

So, this is the removed code:

            if cd.use_cudagraphs:
                from thunder.executors.cudagraphex import cudagraphex

                computation_trc = cudagraphex.fusion_pass(computation_trc)
                computation_traces.append(computation_trc)

                if backward_trc is not None:
                    backward_trc = cudagraphex.fusion_pass(backward_trc, num_static_inputs=len(backward_trc.args[0][0]))
                    backward_traces.append(backward_trc)

@t-vi (Collaborator, Author):

So it seems that these were the saved-for-backward tensors of (???). Is the assumption here that they were either static in the forward (parameters) or copied to the input area of the forward CUDA graph?

But with the old code:

import torch, thunder
with torch.device("cuda"):
    m = torch.nn.Linear(2, 3)
    inp = torch.randn(1, 2, requires_grad=True)
jm = thunder.jit(m, use_cudagraphs=True)

res = jm(inp)
grads = torch.autograd.grad(res.sum(), (inp, *m.parameters()))

the forward has no CUDA graph, so in that case, treating the input as static is not really correct (admittedly, a corner case).

@t-vi (Collaborator, Author):

I guess the happy part is that we will get the parameters as static anyway, and we will have to look into the buffers for our own good...

@crcrpar (Collaborator) left a comment:

would there be a plan to add a test which demonstrates the composability with e.g. fsdp?

@t-vi (Collaborator, Author) commented Aug 16, 2024

> would there be a plan to add a test which demonstrates the composability with e.g. fsdp?

That would be awesome; I really need to get the env for running distributed tests back. :(

@t-vi (Collaborator, Author) commented Aug 16, 2024

So as we talk about plans:

  • refine the static inputs a bit,
  • refine the caching: in particular, don't use a global cache but a per-transform one, so things can go out of scope (see the sketch after this comment),
  • refine the operator selection for what is CUDA-graphable,
  • check composability - I'm guessing the communication prims don't mix with CUDA graphs?

To my mind, this greatly improves our starting position (similar to but smaller than Nik's work to make this a (then hardcoded) transform instead of a callable wrapper).
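
To illustrate the per-transform-cache bullet above: a minimal sketch (mine, not Thunder's implementation) of holding captured CUDA graphs as instance state, so they go out of scope with the transform instead of living in a global dict. All names here are hypothetical.

import torch

class CUDAGraphRunner:
    # Instance-level cache of captured graphs, keyed by input shape/dtype.
    def __init__(self):
        self._cache = {}

    def run(self, fn, x):
        key = (tuple(x.shape), x.dtype)
        if key not in self._cache:
            static_in = x.clone()
            # warm up on a side stream before capture, as PyTorch recommends
            s = torch.cuda.Stream()
            s.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s):
                fn(static_in)
            torch.cuda.current_stream().wait_stream(s)
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph):
                static_out = fn(static_in)
            self._cache[key] = (graph, static_in, static_out)
        graph, static_in, static_out = self._cache[key]
        static_in.copy_(x)  # refresh the captured input buffer
        graph.replay()
        return static_out.clone()

When the runner (or a transform owning it) is garbage collected, the cached graphs and their static buffers are released with it, which is the "let things go out of scope" part.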

@mruberry (Collaborator) commented:
This is cool; how do you see cuda graphs as a transform? If we conceptually separate the transforms like this:

  • pre-execution transforms
  • execution transform
  • post-execution transforms
  • "destructive" transforms

Where would the cuda graphs transform go? If it's a post-execution (or destructive) transform, should we think about how executors might label their operations as being cuda graphs compatible (or not)? If it's a pre-execution transform, then does it preclude using certain executors in the execution transform?

@t-vi (Collaborator, Author) commented Aug 16, 2024

So it goes after the transform for execution ("post optimization"), but our experimentation (which informs some of the CUDAGraph work) caused us to go before the "del" pass.
I think it would be cool if symbols in general had a tag indicating whether they are suitable for CUDAGraph inclusion. The current approach has the advantage that if people have particular ideas about this, they can just subclass and override can_fuse (see the sketch below).
Also, CUDAGraphs could benefit enormously from more information about the memory effects of operators, but that is in the future.
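
As an illustration of the subclass-and-override escape hatch (my sketch, not code from this PR): the can_fuse signature and the bsym.sym.id attribute are assumptions and may not match the actual API.

from thunder.transform.cudagraph import CUDAGraphTransform  # path per the PR description

class SelectiveCUDAGraphTransform(CUDAGraphTransform):
    # hypothetical deny-list standing in for a per-symbol "CUDA-graphable" tag
    DENYLIST = {"torch.nn.functional.embedding"}

    def can_fuse(self, bsym):
        if str(bsym.sym.id) in self.DENYLIST:
            return False
        return super().can_fuse(bsym)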

@t-vi mentioned this pull request Aug 16, 2024
@lantiga (Collaborator) left a comment:

👍

@t-vi enabled auto-merge (squash) August 20, 2024 11:42
@t-vi merged commit e45e4b4 into main Aug 20, 2024
38 checks passed
@t-vi deleted the tom/cudagraphs_transform branch August 20, 2024 11:42
@t-vi (Collaborator, Author) commented Aug 20, 2024

I have merged this after coordination with @IvanYashchuk; we will address further review comments in a follow-up.
