CUDAGraphs as executor/transform/fusion pass #656
Conversation
Looks great overall, thank you @nikitaved.
I have one question in the comments and then I'd be all for merging it.
if bsym.sym.id in do_not_fuse_sym_set:
    return False
...
return True
Would we need to check for fixed sizes of the input and output proxies, or is this handled?
I wonder whether the future API is agreed upon. Does it mean that dynamic-shape tensors will contain integer proxies in their meta-data?
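For illustration, here is a hedged sketch of the kind of check being discussed. The helper name and the assumption that a proxy exposes a `shape` attribute are mine, not the project's API: CUDA Graphs replay with fixed shapes, so a capture pass could reject proxies whose dimensions are symbolic rather than concrete integers.

```python
def has_static_shape(proxy) -> bool:
    # Hypothetical helper: every dimension must be a concrete integer,
    # not a proxy standing in for a runtime value.
    return all(isinstance(dim, int) for dim in proxy.shape)
```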
Nice! This change is needed to make my PR #214 work with CUDA Graphs correctly: there I try to put torch.autograd.Function.apply into the forward trace, but it should be executed outside of the CUDA Graph-captured region.
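A minimal sketch of how such an exclusion could look, based on the snippet quoted above. The set contents and the symbol id are assumptions for illustration only; the point is that wrappers around torch.autograd.Function.apply must run eagerly, outside the fused / CUDA Graph-captured region.

```python
DO_NOT_FUSE_SYM_IDS = {"autograd_function_apply"}  # assumed, illustrative id

def can_fuse(bsym) -> bool:
    # Keep explicitly blacklisted symbols out of the captured/fused region.
    if bsym.sym.id in DO_NOT_FUSE_SYM_IDS:
        return False
    # ... further eligibility checks on the bound symbol ...
    return True
```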
Looks great @nikitaved!
I'm imagining that in the future we could consider not special-casing `use_cudagraphs` but keeping it as a transform, but maybe that's overkill, and it's certainly not for now.
Let's merge and fix anything that needs fixing later.
As per title. Fixes #635.
Also, it fixes the following subtle bugs:
- `CUDAGraphExecutor` does not properly update static buffers when the same graph is invoked on inputs whose meta-data allows fetching a cached graph but whose storage data differs. The area of concern is training and the backward pass (see the replay sketch below).
- `horizontal_merge` in the fusion logic, when grouping bound symbols, does not consider the precedence between ops horizontally. This is not an issue with nvFusions, but it could cause issues when deciding whether to place something like `del x` after `op(x)` in a custom `FusionExecutor`. The fix sorts bsyms in each group with respect to their trace position (the trace is expected to be toposorted prior to any fusions) and hence restores the inter-/intra-group topological order (see the ordering sketch below).
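To illustrate the first bullet, here is a minimal sketch of the standard PyTorch CUDA Graphs replay pattern, not Thunder's `CUDAGraphExecutor` code; `model` is just a stand-in for the captured computation. The `copy_` into the static input buffer is the step whose absence produces the "same meta-data, different storage data" bug.

```python
import torch

model = torch.nn.Linear(8, 8).cuda()
static_input = torch.randn(2, 8, device="cuda")

# Warm up on a side stream before capture (as recommended by the PyTorch docs).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_output = model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture: the graph records operations on the *static* tensors.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

def run(new_input: torch.Tensor) -> torch.Tensor:
    # Refresh the static input buffer before replaying; skipping this copy
    # makes the replay silently reuse stale input data.
    static_input.copy_(new_input)
    graph.replay()
    return static_output.clone()
```

For the second bullet, a hypothetical sketch of the ordering fix (names are illustrative, not the project's API): within each fused group, bound symbols are reordered by their position in the original toposorted trace, so that e.g. `del x` can never be emitted before `op(x)`.

```python
def restore_group_order(groups, trace_bsyms):
    # Map each bound symbol to its index in the (toposorted) trace,
    # then sort every group by that index.
    position = {id(bsym): i for i, bsym in enumerate(trace_bsyms)}
    return [sorted(group, key=lambda bsym: position[id(bsym)]) for group in groups]
```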