Adding shape prim #1113
Conversation
I'm concerned (for the same reasons I have been historically concerned) about putting shape queries in the compute trace. @jjsjann123, maybe your thinking is more advanced than mine, but here are my concerns about doing that:
For points 1 and 2, I think the main concern is about duplicated queries; as you mentioned here, CSE should be a good approach to ensure a unique producer for a proxy (a minimal sketch follows below).
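To make that concrete, here is a minimal CSE sketch over a toy trace representation; the (output, op, args) instruction format and all names are hypothetical, not Thunder's actual IR. Duplicate shape queries collapse to a single producer, and later uses are rewritten to the canonical proxy.

```python
# Minimal CSE sketch over a toy trace; the (output, op, args) format and
# names are hypothetical, not Thunder's actual IR.

def cse(trace):
    seen = {}     # (op, args) -> canonical output name
    renames = {}  # duplicate output -> canonical output
    result = []
    for out, op, args in trace:
        # Rewrite arguments that referred to an eliminated duplicate.
        args = tuple(renames.get(a, a) for a in args)
        key = (op, args)
        if key in seen:
            renames[out] = seen[key]  # duplicate query: reuse the producer
        else:
            seen[key] = out
            result.append((out, op, args))
    return result

trace = [
    ("s0", "shape", ("t0",)),
    ("s1", "shape", ("t0",)),  # identical query of the same tensor
    ("u0", "mul", ("s0", "s1")),
]
print(cse(trace))
# [('s0', 'shape', ('t0',)), ('u0', 'mul', ('s0', 's0'))]
```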
The next concern is about the scope in which the definition of the number proxy lives. Re point 3: … Re point 4: …
I guess this is the part I'm most confused about. If we create a new NumberProxy from a shape query, then won't queries about that NumberProxy be difficult to combine into a set of queries for the actual shape?
I still like the idea of computing the shape symbolically in the prologue, but if you're into this approach then we may as well explore it further.
Another question: what if the symbolic value of the length of a dimension of a tensor is used after the original tensor is out of scope? Would the query for the shape always precede this use and define the name properly in the resulting Python program?
So when we discuss how we want to handle computation of NumberProxy, one of the possibilities is just to annotate those as symbolic formulas, which would allow us to identify identical queries.
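A minimal sketch of that annotation idea, with entirely hypothetical names: each number proxy carries a structural expression, and interning on that expression makes identical queries literally the same object.

```python
# Hypothetical sketch: annotate number proxies with symbolic expressions so
# identical queries can be recognized structurally. Names are illustrative.
from dataclasses import dataclass

@dataclass(frozen=True)
class SymNumber:
    expr: tuple  # e.g. ("dim", "t0", 0) or ("mul", expr_a, expr_b)

_interned = {}

def sym(expr):
    # Interning: structurally identical expressions yield the same proxy.
    return _interned.setdefault(expr, SymNumber(expr))

d0 = sym(("dim", "t0", 0))
d0_again = sym(("dim", "t0", 0))
assert d0 is d0_again  # the duplicate query is recognized

numel = sym(("mul", ("dim", "t0", 0), ("dim", "t0", 1)))
```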
Yeah, I think this is necessary for constraints when we introduce those. We will touch on that topic when we get there, soon~ish.
if after all the transformations, there's still a …
OK, but during execution, then, wouldn't we have to pass the symbolic value as an argument to the computation function?
So for me, concretely, I would be very interested in seeing … work (or, more realistically, CausalSelfAttention, which includes a few reshapes + permutes around the heads, query groups, etc.). As far as I understand, v0 of symbolic constraints could be that we have number proxies for the input shapes, and whenever we hit an executor saying "I want this proxy to be a constant instead of a number input", we constrain all symbols in the expression in the prologue. I'm not sure if NVFuser would be that executor, or whether it can deal with it via its own caching + recompilation (it would be interesting to know, though, which inputs could trigger recompilation), but e.g. the cudagraphs transform has caching/checking that would deal with it automatically.
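A sketch of what that could look like, under assumed (not actual) Thunder APIs: when an executor asks for a constant, the traced value is recorded as a constraint, and the prologue re-checks it on every call, so the specialized computation is only reused for matching inputs.

```python
# Hypothetical sketch of constraining a symbol in the prologue; none of these
# names are real Thunder APIs.

constraints = []  # (tensor_name, dim, value) recorded during tracing

def constrain_to_constant(tensor_name, dim, traced_value):
    # An executor said "I want this proxy to be a constant": record the
    # constraint and hand back the concrete value.
    constraints.append((tensor_name, dim, traced_value))
    return traced_value

def prologue(inputs):
    # Re-check every recorded constraint; a mismatch means the cached
    # computation was specialized for different shapes.
    for name, dim, value in constraints:
        actual = inputs[name].shape[dim]
        if actual != value:
            raise RuntimeError(
                f"{name}.shape[{dim}] is {actual}, but the computation "
                f"was specialized for {value}"
            )

class FakeTensor:  # stand-in so the sketch runs without torch
    def __init__(self, shape):
        self.shape = shape

constrain_to_constant("x", 0, 8)
prologue({"x": FakeTensor((8, 128))})  # passes; (16, 128) would raise
```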
I like how we are getting an in-depth discussion on how we would want to support reshape in this PR. I think both approaches that we discussed (lifting shape inference into the prologue / leaving it in the compute trace) would be able to support these workflows. The difference is about performance, which I'm not ready to answer yet, and this PR by itself doesn't make that decision either. I have an issue to track the conversation we are having. I think that's enough to unblock the continued review/merge of this PR for the shape prim.
From my POV, this PR is good to merge.
FYI, I'm validating the performance on this PR with …
Sounds great, let's merge tomorrow if the benchmarking does not find issues. |
In the benchmark result (https://gist.github.com/jjsjann123/11ecc9f1d0ddc53b6df389a525e35373), comparing this branch with the main branch it's based on, I'm not seeing any significant difference in the median time between the two. I think we are good to go. cc'ing @mruberry as well as @tfogal.
Seems good, unless anyone else objects.
Thank you @jjsjann123 @mruberry
Adding prims.shape. Follow-up on closed PR #1061: per the comments there, we went with the plan of unpacking the shape in the compute trace before its use.
Access to the shape will be visible in the trace when we have NumberProxy enabled in #1027.
Note that today, since the entries of TensorProxy._shape are just constants, the shape prim should be DCE'ed away in the final trace: the first compute trace contains the explicit shape query, and constant folding plus DCE then remove it.
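To illustrate with the same toy (output, op, args) format used above (hypothetical, not Thunder's actual trace representation): once the shape results are folded to constants, the shape instruction has no remaining uses and a simple DCE pass drops it.

```python
# Toy DCE sketch: drop instructions whose outputs are never used. The IR
# format is hypothetical, not Thunder's actual trace representation.

def dce(trace, live_outputs):
    live = set(live_outputs)
    kept = []
    for out, op, args in reversed(trace):
        if out in live:
            kept.append((out, op, args))
            live.update(a for a in args if isinstance(a, str))
    return list(reversed(kept))

# The shape result was folded into reshape's constant args, so s0 is dead:
trace = [
    ("s0", "shape", ("t0",)),
    ("t1", "reshape", ("t0", 4, 2)),  # constants instead of shape proxies
]
print(dce(trace, live_outputs=["t1"]))
# [('t1', 'reshape', ('t0', 4, 2))]
```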
Future PR:
We'll need to enable it with the nvfuser executor later. Right now that's tricky to do, since the shape prim would produce a scalar output, which nvfuser can't handle yet. We can expand fusion_pass to support that.
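One way the fusion pass could be expanded, sketched over the same toy IR (none of this is nvfuser's actual API): split the trace so that scalar-producing ops like the shape prim run outside fusion regions, leaving only tensor ops for the fusion executor.

```python
# Hypothetical partitioning sketch: keep scalar-producing ops (like the shape
# prim) out of fusion regions, since nvfuser can't return scalar outputs yet.

SCALAR_PRODUCING = {"shape"}

def partition(trace):
    regions, current = [], []
    for instr in trace:
        _, op, _ = instr
        if op in SCALAR_PRODUCING:
            if current:
                regions.append(("fusion", current))
                current = []
            regions.append(("eager", [instr]))  # executed outside the fusion
        else:
            current.append(instr)
    if current:
        regions.append(("fusion", current))
    return regions

trace = [
    ("t1", "add", ("t0", "t0")),
    ("s0", "shape", ("t1",)),          # scalar output breaks the region
    ("t2", "reshape", ("t1", "s0")),
]
for kind, instrs in partition(trace):
    print(kind, [op for _, op, _ in instrs])
# fusion ['add']
# eager ['shape']
# fusion ['reshape']
```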