Attach cached cudagraph callable to the transform #1001
Conversation
IvanYashchuk left a comment:
The code is organized a bit differently now. Looking forward to where it will take us!
```python
self.python_callables[x_fn_name] = (
    self.make_python_callable_from_symbols(fn_name, bsyms, inputs, outputs),
    static_inputs_mask,
)
self.trace_symbols[x_fn_name] = (bsyms, inputs, outputs)
```
How can I access these dictionaries in a regular user script? How would developers learn that this information is saved and available for inspection?
Currently, you have to look at the source. I'm not yet sure this is the exact information we want to keep; we should refine this.
So, in addition to everything else, the `x_fn_name` is a bit awkward; the trick was that the transform is applied to both forward and backward. If we are OK with cudagraph regions not starting from 1 in the backward, I would just generate the name in the runner and return it.
That would make it easy to link from the trace to the cache name here. Is that what you have in mind?
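The naming scheme discussed above could look like the following minimal sketch. Everything here is hypothetical: the class and method names (`CUDAGraphRunner`, `register`) and the `CUDAGraph{n}` name format are assumptions for illustration, not the actual implementation. The point is that the runner owns the counter and returns the generated name, so the caller can record the trace-to-cache link.

```python
import itertools


class CUDAGraphRunner:
    """Hypothetical sketch: the runner owns the name counter, so region
    names are unique per runner instead of restarting per transform pass
    (e.g. forward vs. backward)."""

    def __init__(self):
        self._counter = itertools.count()
        self.python_callables = {}

    def register(self, callable_, static_inputs_mask):
        # Generate the name here and hand it back to the caller, so the
        # trace can record the link trace -> cache name.
        name = f"CUDAGraph{next(self._counter)}"
        self.python_callables[name] = (callable_, static_inputs_mask)
        return name
```

With this shape, a backward region registered after two forward regions would simply get the next index rather than restarting from 1.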
```python
def build_cuda_graph(
    self, fn: Callable, args: list[any], static_args_mask: tuple[bool, ...]
) -> tuple[torch.cuda.CUDAGraph, Sequence[torch.Tensor | Any], Sequence[torch.Tensor | Any]]:
```
What is the benefit of having this as a method and not as a free function?
My thinking was that the ability to override the buffer allocation (`get_static_buffer`) could be useful. We could have it as a callback instead, but then it might just as well be a method.
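The method-vs-free-function trade-off mentioned here can be sketched abstractly. This is illustrative only: `Runner`, `prepare_inputs`, and `ZeroedRunner` are invented names, and plain lists stand in for CUDA tensors, so the example runs without a GPU. The design point is that a subclass can swap out the static-buffer allocation policy.

```python
class Runner:
    def get_static_buffer(self, arg):
        # Default policy: copy the input into a fresh runner-owned buffer
        # that a captured graph could keep reading from.
        return list(arg)

    def prepare_inputs(self, args, static_inputs_mask):
        # Static inputs are used as-is (their storage is assumed stable);
        # the rest are copied into static buffers.
        return [
            arg if is_static else self.get_static_buffer(arg)
            for arg, is_static in zip(args, static_inputs_mask)
        ]


class ZeroedRunner(Runner):
    def get_static_buffer(self, arg):
        # Override: allocate zero-initialized buffers instead of copying.
        return [0] * len(arg)
```

As a free function, swapping the allocation policy would require threading a callback through every call site; as a method, a subclass overrides one hook.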
```python
    static_inputs_mask = (True,) * self.num_static_inputs + (False,) * (len(args) - self.num_static_inputs)
else:
    static_inputs_mask = tuple(isinstance(arg, torch.nn.Parameter) for arg in args)
```
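The mask logic quoted above can be exercised in isolation. In this sketch, `Parameter` is a placeholder for `torch.nn.Parameter` so the example runs without PyTorch, and the function name `static_inputs_mask` is chosen to mirror the variable in the snippet.

```python
class Parameter:
    """Placeholder for torch.nn.Parameter."""


def static_inputs_mask(args, num_static_inputs=None):
    if num_static_inputs is not None:
        # The first num_static_inputs arguments are treated as static.
        return (True,) * num_static_inputs + (False,) * (len(args) - num_static_inputs)
    # Otherwise, treat module parameters as static: their storage stays
    # stable across calls, so the graph can read them in place.
    return tuple(isinstance(arg, Parameter) for arg in args)
```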
```python
def make_python_callable_from_symbols(
```
This looks like something we could have in a transform building-block library in the future.
Indeed! The question is whether we should build the region before...
So I'm merging this now; I expect to do another PR or two in the next few days around the developer experience of this, plus any further feedback.
This introduces a `CUDAGraphRunner` class that bundles the building and caching of cudagraphs and attaches them to the `CUDAGraphTransform` introduced in #977.
(Design issue #981 )
From my POV this is ready for review, but it builds on #977, so I'm labeling it draft.
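The build-and-cache pattern the description refers to can be sketched as follows. This is a hypothetical stand-in: `GraphCache`, its methods, and the string "graphs" it builds are all invented for illustration; the real runner would capture a `torch.cuda.CUDAGraph` in the build step and replay it on cache hits.

```python
class GraphCache:
    """Illustrative sketch of build-once, replay-afterwards caching."""

    def __init__(self):
        self._graphs = {}
        self.build_count = 0

    def _build(self, key):
        # Stand-in for capturing a CUDA graph for this callable/region key.
        self.build_count += 1
        return f"graph-for-{key}"

    def get(self, key):
        # Build at most once per key; subsequent calls reuse the cached graph.
        if key not in self._graphs:
            self._graphs[key] = self._build(key)
        return self._graphs[key]
```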