
Adding shape prim #1113

Merged — 12 commits merged into main on Sep 11, 2024
Conversation

@jjsjann123 (Collaborator) commented on Sep 6, 2024

Adding prims.shape.

Follow-up to closed PR #1061; per the discussion there, we went with the plan to unpack shape in the compute trace before its use.

Accesses of shape will become visible in the trace once NumberProxy is enabled in #1027.

Note that today, since TensorProxy._shape entries are just constants, the shape prim should be DCE'd away in the final trace. In the example below:

def foo(a, b):
    return a * b.size(0)

jfoo = thunder.jit(foo)
a = torch.randn(2, 2, device="cuda")
a.requires_grad_()
b = torch.randn(1, 2, device="cuda")
out = jfoo(a, b)

The first compute trace looks like:

def computation(a, b):
  # a: "cuda:0 f32[2, 2]"
  # b: "cuda:0 f32[1, 2]"

  # /volume/thunder_dynamic/test.py:7:      return a * b.size(0)
  _ = ltorch.size(b, 0)
    # (_, _) = prims.size(b)
  t0 = ltorch.mul(a, 1)  # t0: "cuda:0 f32[2, 2]"
    # _ = prims.convert_element_type(1, float)
    # t0 = prims.mul(a, 1.0)  # t0: "cuda:0 f32[2, 2]"
  return t0
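
For comparison, here is a hypothetical sketch of what the same trace could look like once NumberProxy is enabled per #1027; the exact names and formatting are assumptions, not output produced by this PR:

def computation(a, b):
  # a: "cuda:0 f32[2, 2]"
  # b: "cuda:0 f32[1, 2]"
  (i0, i1) = prims.shape(b)  # i0, i1: NumberProxy instead of baked-in constants
  t0 = ltorch.mul(a, i0)     # t0: "cuda:0 f32[2, 2]"
  return t0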

Future PR:
We'll need to enable this with the nvFuser executor later. Right now that's tricky, since the shape prim would produce a scalar output, which nvFuser can't handle yet. We can expand fusion_pass to support that.
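
As a rough illustration only (not nvFuser's or Thunder's actual fusion-pass API), the idea is to partition bound symbols so that scalar-producing shape queries stay outside the fusion region and the fusion never has to return a scalar output:

# Rough illustration; `is_shape_query` is a hypothetical predicate.
def partition_for_fusion(bound_symbols, is_shape_query):
    fusible, outside = [], []
    for bsym in bound_symbols:
        # Scalar-producing shape queries are kept out of the fusion region.
        (outside if is_shape_query(bsym) else fusible).append(bsym)
    return fusible, outside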

@jjsjann123 changed the title from "[DO NOT REVIEW] Shape prim" to "Adding shape prim" on Sep 6, 2024
@jjsjann123 marked this pull request as ready for review on September 6, 2024, 22:29
Review thread on thunder/torch/__init__.py (outdated, resolved)
@mruberry (Collaborator) commented on Sep 9, 2024

I'm concerned (for the same reasons I have been historically concerned) about putting shape in the trace.

@jjsjann123, maybe your thinking is more advanced than mine, but here are my concerns about putting shape in the trace:

  1. The shape of a tensor is frequently queried, and preserving each of those queries will mean multiple bound symbols produce the same number proxies. I think we implicitly assume that each proxy has a unique producer today (@IvanYashchuk probably knows the latest thinking here). By default, I think the shape queries would be CSE'd, although we can mark them as not eligible for CSE. If all but the first shape query are CSE'd, then I worry about the number proxies representing the lengths of a tensor showing up when they are not actually in scope for the Python program that Thunder generates. Part of this concern is that we're not completely sure how operations on symbolic values within primitives' meta functions will be accounted for.
  2. The shape of a tensor is frequently queried, and preserving each of these queries will be slower than caching the query result. Of course, this criticism could be minimized if we preserve only the first query in the computation, but that may have its own issues (see above).
  3. The shape of a tensor will almost certainly be queried in the prologue, so querying it again in the computation is redundant
  4. We would have to teach nvFuser (and maybe other fusion executors) how to fuse the shape query operations so that they don't prevent fusing operations around them

@jjsjann123 (Collaborator, Author) commented:

For points 1 and 2, I think the main concern is about how each proxy gets a unique producer.

As you mentioned, CSE looks like a good approach to tackle the unique producer of a proxy.

I worry about the number proxies representing the lengths of a tensor showing up when they are not actually in scope for the Python program that Thunder generates. Part of this concern is that we're not completely sure how operations on symbolic values within primitives' meta functions will be accounted for.

The next concern is about the scope in which the number proxy is defined.
If we want to use operations in the trace to compute the number proxy, then that's a real concern. But if that logic is encoded in the number proxy itself, then we know that by the time a TensorProxy has been created, its shape values are already populated.

I'm viewing prims.shape as a source node, i.e. an attribute query on an existing Tensor/TensorProxy. The intent is not to associate it with the full trace of how a number proxy is produced, but rather to create a new NumberProxy. This is useful by itself, e.g. for a data-dependent shape that doesn't need to be constrained.
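
As a minimal sketch of that "source node" idea (illustrative only, not necessarily this PR's exact meta function), the shape prim just returns the per-dimension values already attached to the TensorProxy, with no record of how they were computed:

# Illustrative sketch, assuming TensorProxy.shape already holds per-dimension values.
def shape_meta(a):
    # A pure attribute query: one value/proxy per dimension of `a`,
    # with no trace of how each value was derived.
    return tuple(a.shape)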

Re point 3:
That's true for input tensors, but is it true for intermediates as well? I think this will be determined by how generally we resolve things like broadcast and other shape checks for intermediates.
I'm hoping that in cases where no shape-based control flow is executed, shape queries will be limited to input tensors only, and we won't have to lift the entire shape propagation into the prologue.

Re point 4:
Yes, there is going to be extra work. I'm less worried about nvFuser handling shape ops that are contained within a fusion, but shape outputs that leak out of a fusion are a lot messier. See NVIDIA/Fuser#315.

@mruberry (Collaborator) commented on Sep 9, 2024

The next concern is about the scope in which the number proxy is defined. If we want to use operations in the trace to compute the number proxy, then that's a real concern. But if that logic is encoded in the number proxy itself, then we know that by the time a TensorProxy has been created, its shape values are already populated.

I'm viewing prims.shape as a source node, i.e. an attribute query on an existing Tensor/TensorProxy. The intent is not to associate it with the full trace of how a number proxy is produced, but rather to create a new NumberProxy. This is useful by itself, e.g. for a data-dependent shape that doesn't need to be constrained.

I guess this is the part I'm most confused about. If we create a new NumberProxy from a shape query, then won't queries about that NumberProxy be difficult to combine into a set of queries for the actual shape?

Re point 3: That's true for input tensors, but is it true for intermediates as well? I think this will be determined by how generally we resolve things like broadcast and other shape checks for intermediates. I'm hoping that in cases where no shape-based control flow is executed, shape queries will be limited to input tensors only, and we won't have to lift the entire shape propagation into the prologue.

I still like the idea of computing the shape symbolically in the prologue, but if you're into this approach then we may as well explore it further.

@mruberry (Collaborator) commented on Sep 9, 2024

Another question:

what if the symbolic value of the length of a dimension of a tensor is used after the original tensor is out of scope? Would the query for the shape always precede this use and define the name properly in the resulting Python program?
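
To make the scenario concrete, a hypothetical trace fragment (not produced by this PR):

(s0, s1) = prims.shape(t0)   # shape queried while t0 is alive
del t0                       # t0 goes out of scope
t5 = ltorch.mul(t1, s0)      # s0 still used afterwards -- is its name guaranteed to be defined here?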

@jjsjann123 (Collaborator, Author) commented:

The next concern is about the scope in which the number proxy is defined. If we want to use operations in the trace to compute the number proxy, then that's a real concern. But if that logic is encoded in the number proxy itself, then we know that by the time a TensorProxy has been created, its shape values are already populated.
I'm viewing prims.shape as a source node, i.e. an attribute query on an existing Tensor/TensorProxy. The intent is not to associate it with the full trace of how a number proxy is produced, but rather to create a new NumberProxy. This is useful by itself, e.g. for a data-dependent shape that doesn't need to be constrained.

I guess this is the part I'm most confused about. If we create a new NumberProxy from a shape query, then won't queries about that NumberProxy be difficult to combine into a set of queries for the actual shape?

When we discuss how we want to handle the computation of NumberProxy values, one possibility is to annotate them with symbolic formulas, which would allow us to identify identical queries, e.g.:

# assume t0 is an input
(s0, s1) = prims.shape(t0)  # s0: t0.size(0), s1: t0.size(1)
s2 = prims.mul(s0, s1)        # s2: t0.size(0) * t0.size(1)
(s3, s4) = prims.shape(t0)  # s3: t0.size(0), s4: t0.size(1)
s5 = prims.mul(s3, s4)        # s5: t0.size(0) * t0.size(1)

Even though s2 and s5 are viewed as different proxies, we should still resolve them to the same symbolic value relatively easily.
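
A minimal sketch of that resolution, assuming each NumberProxy carries (or can be annotated with) a symbolic expression; the attribute name below is an assumption, not an existing Thunder API:

# Illustrative only: key each proxy by its (hypothetical) symbolic-expression
# annotation to detect duplicates such as s2 and s5 above.
def group_by_symbolic_expr(proxies):
    groups = {}
    for p in proxies:
        # p.sym_expr is an assumed annotation, e.g. "t0.size(0) * t0.size(1)"
        groups.setdefault(p.sym_expr, []).append(p)
    return groups

# With the example above, s2 and s5 would land in the same group:
# group_by_symbolic_expr([s2, s5])["t0.size(0) * t0.size(1)"] == [s2, s5]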

Re point 3: That's true for input tensors, but is it true for intermediates as well? I think this will be determined by how generally we resolve things like broadcast and other shape checks for intermediates. I'm hoping that in cases where no shape-based control flow is executed, shape queries will be limited to input tensors only, and we won't have to lift the entire shape propagation into the prologue.

I still like the idea of computing the shape symbolically in the prologue, but if you're into this approach then we may as well explore it further.

Yeah, I think this is necessary for constraints when we introduce those. We'll touch on that topic when we get there, soon-ish.

what if the symbolic value of the length of a dimension of a tensor is used after the original tensor is out of scope? Would the query for the shape always precede this use and define the name properly in the resulting Python program?

If, after all the transformations, there's still a prims.shape symbol in the graph, an executor would retrieve the tensor's shape when executing the symbol and detach those numbers from the tensor itself.
At the execution phase, I don't think we need to associate these NumberProxys with each other anymore. The association should only be of interest while we construct constraints.
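
A rough sketch of that execution-time lowering (illustrative; not the actual executor registration in Thunder):

import torch

# Executing the shape symbol at runtime is just an attribute read; the returned
# Python ints are detached from any proxy bookkeeping.
def shape_impl(t: torch.Tensor) -> tuple:
    return tuple(t.shape)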

@mruberry (Collaborator) commented on Sep 9, 2024

If, after all the transformations, there's still a prims.shape symbol in the graph, an executor would retrieve the tensor's shape when executing the symbol and detach those numbers from the tensor itself. At the execution phase, I don't think we need to associate these NumberProxys with each other anymore. The association should only be of interest while we construct constraints.

OK, but during execution, then, wouldn't we have to pass the symbolic value as an argument to the computation function?

@t-vi (Collaborator) commented on Sep 9, 2024

I guess this is the part I'm most confused about. If we create a new NumberProxy from a shape query, then won't queries about that NumberProxy be difficult to combine into a set of queries for the actual shape?

So for me, concretely, I would be very interested in seeing

def flatten_more_or_less(a):
    return a.view(a.shape[0] * a.shape[1], a.shape[2] * a.shape[3])

work (or, more realistically, CausalSelfAttention, which includes a few reshape + permutes around the heads, query groups etc.).
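
For concreteness, the kind of head-splitting reshape meant here looks roughly like this (standard multi-head attention bookkeeping; names and shapes are illustrative):

import torch

def split_heads(x: torch.Tensor, n_head: int) -> torch.Tensor:
    # Each of B, T, C below would become a shape query (and thus a NumberProxy)
    # under this PR's prims.shape.
    B, T, C = x.shape
    return x.view(B, T, n_head, C // n_head).transpose(1, 2)  # (B, n_head, T, C // n_head)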

As far as I understand, v0 of symbolic constraints could be that we have number proxies for the input shapes, and whenever we hit an executor saying "I want this proxy to be a constant instead of a number input", we constrain all the symbols in that expression in the prologue. I'm not sure whether nvFuser would be that executor, or whether it can deal with it via its own caching + recompilation (it would be interesting to know, though, which inputs could trigger recompilation), but e.g. the cudagraphs transform has caching/checking that would deal with it automatically.
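
A hypothetical sketch of what such a v0 constraint could look like in the prologue, assuming an executor asked for the first input length to be treated as a constant (the check below is illustrative, not Thunder's actual prologue code):

# Hypothetical prologue fragment: the symbol the executor specialized on is pinned
# by an explicit check, so a different value at call time forces recompilation.
def prologue(a):
    assert a.shape[0] == 2, "a.shape[0] was specialized to 2; recompilation needed"
    return a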

@jjsjann123 mentioned this pull request on Sep 10, 2024
@jjsjann123 (Collaborator, Author) commented:

I like that we're getting an in-depth discussion in this PR of how we'd want to support reshape. I think both approaches we discussed (lifting shape inference into the prologue / leaving it in the compute trace) would be able to support these workflows. The difference is about performance, which I'm not ready to answer yet.

And this PR by itself doesn't make that decision either. I have an issue open to track the conversation we are having.

I think that's enough to unblock the continued review/merge of this PR for the shape prim.

@t-vi (Collaborator) commented on Sep 10, 2024

From my POV, this PR is good to merge.

@jjsjann123 (Collaborator, Author) commented:

FYI, I'm validating the performance of this PR with thunder/benchmark/targets.py. Since shapes should currently be baked in as constants, I'm expecting no perf impact.

@t-vi (Collaborator) commented on Sep 10, 2024

Sounds great, let's merge tomorrow if the benchmarking does not find issues.

@jjsjann123 (Collaborator, Author) commented:

Benchmark results: https://gist.github.com/jjsjann123/11ecc9f1d0ddc53b6df389a525e35373

Comparing this branch with the main it's based on, I'm not seeing any significant difference in median time between the two. I think we're good to go. cc'ing @mruberry as well as @tfogal.

@t-vi (Collaborator) left a review comment:

Seems good, unless anyone else objects.

Thank you @jjsjann123 @mruberry

@t-vi enabled auto-merge (squash) on September 11, 2024, 16:15
@t-vi merged commit 50f587d into main on Sep 11, 2024 (37 checks passed)
@t-vi deleted the shape_prim branch on September 11, 2024, 16:15
@jjsjann123 mentioned this pull request on Sep 6, 2024