Attach cached cudagraph callable to the transform #1001
Conversation
The code is organized a bit differently now. Looking forward to where it will take us!
self.python_callables[x_fn_name] = (
    self.make_python_callable_from_symbols(fn_name, bsyms, inputs, outputs),
    static_inputs_mask,
)
self.trace_symbols[x_fn_name] = (bsyms, inputs, outputs)
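As a hedged illustration of the caching structure in the snippet above (the class and the `register` helper are hypothetical names for this sketch, not Thunder's actual API), the two dictionaries could be populated roughly like this:

```python
class CUDAGraphTransformSketch:
    """Illustrative sketch only: mirrors the two caches in the snippet above."""

    def __init__(self):
        # fn_name -> (python callable, static_inputs_mask)
        self.python_callables = {}
        # fn_name -> (bound symbols, inputs, outputs)
        self.trace_symbols = {}

    def register(self, fn_name, fn, static_inputs_mask, bsyms, inputs, outputs):
        # Hypothetical helper: record a compiled region so it can be
        # looked up and inspected later by name.
        self.python_callables[fn_name] = (fn, static_inputs_mask)
        self.trace_symbols[fn_name] = (bsyms, inputs, outputs)
```

Keeping both dictionaries keyed by the same name makes it cheap to go from a cached callable back to the symbols it was built from.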
How can I access these dictionaries from a regular user script? And how are developers supposed to learn that this information is saved and available for inspection?
Currently, you have to look at the source. I'm not yet sure this is the exact information we want to keep; we should refine this.
So, in addition to everything else, the x_fn_name is a bit awkward; the trick was that the transform is applied to both forward and backward. If we are OK with cudagraph regions not starting from 1 in the backward, I would just generate the name in the runner and return it.
That would make it easy to link from the trace to the cache name here. Is that what you have in mind?
def build_cuda_graph(
    self, fn: Callable, args: list[Any], static_args_mask: tuple[bool, ...]
) -> tuple[torch.cuda.CUDAGraph, Sequence[torch.Tensor | Any], Sequence[torch.Tensor | Any]]:
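For context, a method with this signature would typically follow PyTorch's standard capture pattern: warm up on a side stream, then record the kernels into a `torch.cuda.CUDAGraph`. A minimal free-function sketch of that pattern (assuming a CUDA device is available; this is not the PR's actual implementation):

```python
def build_cuda_graph(fn, static_inputs, warmup_iters=3):
    """Sketch of the standard warmup/capture pattern for CUDA graphs."""
    # Deferred import so the sketch can be read and imported without CUDA.
    import torch

    # Warm up on a side stream so allocator setup is not recorded.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        for _ in range(warmup_iters):
            static_outputs = fn(*static_inputs)
    torch.cuda.current_stream().wait_stream(side)

    # Capture: every kernel launched inside this block is recorded.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_outputs = fn(*static_inputs)

    return graph, static_inputs, static_outputs
```

To replay, new data is copied into the static input buffers and `graph.replay()` is called; the results then appear in the static output buffers.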
What is the benefit of having this as a method and not as a free function?
My thinking was that the ability to override the buffer allocation (get_static_buffer) could be useful. We could have it as a callback instead, but then it might just as well be a method.
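To illustrate the method-vs-free-function trade-off: keeping it a method lets a subclass swap the allocation policy without touching the calling code. A toy sketch with hypothetical names (plain lists stand in for tensors):

```python
class RunnerSketch:
    """Toy model of a runner whose buffer allocation can be overridden."""

    def get_static_buffer(self, example):
        # Default policy: a fresh private copy of the example value.
        return list(example)

    def allocate_static_inputs(self, args, static_inputs_mask):
        # Static inputs get dedicated buffers; the rest pass through.
        return [self.get_static_buffer(a) if static else a
                for a, static in zip(args, static_inputs_mask)]


class TrackingRunner(RunnerSketch):
    """Override example: record every buffer handed out."""

    def __init__(self):
        self.allocated = []

    def get_static_buffer(self, example):
        buf = super().get_static_buffer(example)
        self.allocated.append(buf)
        return buf
```

A free function would need the allocation policy threaded through as an extra callback argument, which is the "might just as well be a method" point.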
Looks great
    static_inputs_mask = (True,) * self.num_static_inputs + (False,) * (len(args) - self.num_static_inputs)
else:
    static_inputs_mask = tuple(isinstance(arg, torch.nn.Parameter) for arg in args)

def make_python_callable_from_symbols(
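The mask logic above, extracted as a standalone helper for illustration (the helper name is hypothetical, and a pluggable `is_param` predicate stands in for `isinstance(arg, torch.nn.Parameter)` so the sketch runs without torch):

```python
def build_static_inputs_mask(args, num_static_inputs=None, is_param=None):
    """Sketch: mark the first num_static_inputs args static, or fall back
    to a per-argument parameter check when no explicit count is given."""
    if num_static_inputs is not None:
        return (True,) * num_static_inputs + (False,) * (len(args) - num_static_inputs)
    # Default predicate: duck-typed stand-in for isinstance(a, torch.nn.Parameter).
    is_param = is_param or (lambda a: getattr(a, "is_param", False))
    return tuple(is_param(a) for a in args)
```

The fallback branch encodes the assumption that module parameters live at stable addresses across calls, so they can be captured in place rather than copied into staging buffers.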
This looks like something we could have in a transform building-block library in the future.
Indeed! The question is whether we should build the region before...
So I'm merging this now; I expect to do another PR or two in the next few days around the developer experience of this, plus any further feedback.
This introduces a CUDAGraphRunner class to bundle the building and caching of CUDA graphs, and attaches it to the CUDAGraphTransform introduced in #977.
(Design issue #981 )
From my POV this is ready for review, but it builds on #977, so I'm labeling it draft.
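The build-and-cache responsibility described above can be sketched as a cache-on-first-use runner (all names are hypothetical for this sketch, and plain callables stand in for captured graphs):

```python
class CUDAGraphRunnerSketch:
    """Toy runner: build a region's graph once, then reuse it on every call."""

    def __init__(self):
        self.graphs = {}  # region name -> built "graph"

    def __call__(self, name, build_fn, *args):
        if name not in self.graphs:
            # First call for this region: build and cache.
            self.graphs[name] = build_fn()
        # Subsequent calls: "replay" the cached object.
        return self.graphs[name](*args)
```

Attaching such a runner to the transform is what makes the cached callables reachable from the transform object after compilation.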