
Attach cached cudagraph callable to the transform #1001

Merged · 8 commits · Aug 20, 2024
Conversation

@t-vi (Collaborator) commented Aug 20, 2024

This introduces a CUDAGraphRunner class that bundles building and caching cudagraphs, and attaches it to the CUDAGraphTransform introduced in #977.

(Design issue #981 )

From my POV this is ready for review, but it builds on #977, so I'm labeling it draft.
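The build-and-cache pattern such a runner bundles can be shown in a minimal, torch-free sketch (class and method names here are illustrative stand-ins, not Thunder's actual API):

```python
from typing import Any, Callable


class GraphRunnerSketch:
    """Caches a 'captured' callable per region name, building it on first use."""

    def __init__(self) -> None:
        self.cache: dict[str, Callable[..., Any]] = {}

    def build(self, name: str, fn: Callable[..., Any]) -> Callable[..., Any]:
        # Stand-in for capturing a CUDA graph: wrap fn once and reuse the wrapper.
        def replay(*args: Any) -> Any:
            return fn(*args)

        return replay

    def call(self, name: str, fn: Callable[..., Any], *args: Any) -> Any:
        if name not in self.cache:
            self.cache[name] = self.build(name, fn)  # build on cache miss
        return self.cache[name](*args)  # replay the cached callable afterwards


runner = GraphRunnerSketch()
out = runner.call("region_0", lambda a, b: a + b, 2, 3)
```

The real runner additionally handles static input buffers and graph replay on the GPU; the sketch only shows the per-region caching.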

@t-vi (Collaborator, Author) commented Aug 20, 2024

@nikitaved @IvanYashchuk

@t-vi t-vi mentioned this pull request Aug 20, 2024
@IvanYashchuk (Collaborator) left a comment

The code is organized a bit differently now. Looking forward to where it will take us!

thunder/transforms/cudagraph.py (outdated; resolved)
Comment on lines +160 to +164
self.python_callables[x_fn_name] = (
    self.make_python_callable_from_symbols(fn_name, bsyms, inputs, outputs),
    static_inputs_mask,
)
self.trace_symbols[x_fn_name] = (bsyms, inputs, outputs)
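The two caches in the quoted lines can be mirrored in a small, hypothetical stand-in (the names and tuple shapes follow the snippet; the values below are dummies):

```python
# Hypothetical stand-ins for the two dictionaries kept on the transform.
python_callables = {}
trace_symbols = {}


def register_region(name, callable_, static_inputs_mask, bsyms, inputs, outputs):
    # Mirrors the quoted code: the callable plus its static-input mask go in
    # one dict, the trace-level symbols in the other, keyed by the same name.
    python_callables[name] = (callable_, static_inputs_mask)
    trace_symbols[name] = (bsyms, inputs, outputs)


register_region("CUDAGraph0_fwd", lambda x: x, (True, False), ["bsym"], ["in"], ["out"])
```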
Collaborator:

How can I access these dictionaries in the regular user script? How should developers get to know that this information is saved and available for inspection?

@t-vi (Collaborator, Author):

Currently, you have to look at the source. I'm not yet sure this is the exact information we want to keep; we should refine it.

@t-vi (Collaborator, Author):

In addition to everything else, the x_fn_name is a bit awkward; the trick is that the transform is applied to both forward and backward. If we are OK with cudagraph region numbering not starting from 1 in the backward, I would just generate the name in the runner and return it.
That would make it easy to link from the trace to the cache name here. Is that what you have in mind?

Comment on lines +52 to +54
def build_cuda_graph(
    self, fn: Callable, args: list[Any], static_args_mask: tuple[bool, ...]
) -> tuple[torch.cuda.CUDAGraph, Sequence[torch.Tensor | Any], Sequence[torch.Tensor | Any]]:
Collaborator:

What is the benefit of having this as a method and not as a free function?

@t-vi (Collaborator, Author) commented Aug 20, 2024:

My thinking was that the ability to override the buffer allocation (get_static_buffer) could be useful. We could have it as a callback instead, but then it might just as well be a method.
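The method-over-free-function argument, in a small hypothetical sketch: a subclass can override the buffer-allocation hook without touching the build logic (names here are illustrative; the real `get_static_buffer` allocates CUDA tensors).

```python
class RunnerBase:
    """Minimal stand-in for a runner whose buffer allocation is an overridable method."""

    def get_static_buffer(self, value):
        # Default hook: copy the value into a fresh "static" slot (a list here).
        return [value]

    def build(self, args):
        # Build logic stays fixed; only the allocation hook varies by subclass.
        return [self.get_static_buffer(a) for a in args]


class PooledRunner(RunnerBase):
    def __init__(self):
        self.pool = []

    def get_static_buffer(self, value):
        # Override: keep every allocated buffer in a pool for reuse/inspection.
        buf = [value]
        self.pool.append(buf)
        return buf


r = PooledRunner()
bufs = r.build([1, 2])
```

With a free function, customizing the allocation would require threading a callback through every call site; as a method, subclassing suffices.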

@IvanYashchuk IvanYashchuk changed the title attach cudagraph cache to the transform Attach cached cudagraph callable to the transform Aug 20, 2024
Base automatically changed from tom/cudagraphs_transform to main August 20, 2024 11:42
@lantiga (Collaborator) left a comment

Looks great

    static_inputs_mask = (True,) * self.num_static_inputs + (False,) * (len(args) - self.num_static_inputs)
else:
    static_inputs_mask = tuple(isinstance(arg, torch.nn.Parameter) for arg in args)
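The mask logic quoted above, restated as a torch-free sketch (the `is_param` predicate is a stand-in for the `isinstance(arg, torch.nn.Parameter)` check):

```python
def compute_static_inputs_mask(args, num_static_inputs=None, is_param=lambda a: False):
    """If num_static_inputs is given, the leading entries are static;
    otherwise, mark exactly the parameter-like arguments as static."""
    if num_static_inputs is not None:
        return (True,) * num_static_inputs + (False,) * (len(args) - num_static_inputs)
    return tuple(is_param(a) for a in args)


# With an explicit count: the first two args (e.g. weight and bias) are static.
mask = compute_static_inputs_mask(["w", "b", "x"], num_static_inputs=2)
```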
def make_python_callable_from_symbols(
Collaborator:

This looks like something we could have in a transform building block library in the future

@t-vi (Collaborator, Author):

Indeed! The question is whether we should build the region before...

@t-vi t-vi marked this pull request as ready for review August 20, 2024 11:52
@t-vi t-vi requested a review from mruberry as a code owner August 20, 2024 11:52
@t-vi (Collaborator, Author) commented Aug 20, 2024

So I'm merging this now; I expect to do another PR or two in the next few days around the developer experience of this, plus any further feedback.
(One other thing I want to do is perhaps pass in the static-argument mask, which we might obtain from analysing the graph(s): parameters and buffers from the module are static, etc.)

@t-vi t-vi merged commit 7c425fe into main Aug 20, 2024
40 checks passed
@t-vi t-vi deleted the tom/cudagraphs-cache branch August 20, 2024 12:23
3 participants