Modeling of shape queries #1133

Open · 1 task
jjsjann123 opened this issue Sep 10, 2024 · 9 comments

Labels: design (This is a largish feature / design), enhancement (New feature or request)
@jjsjann123 (Collaborator) commented Sep 10, 2024

🚀 Feature

Modeling of accessing the shape attribute of Tensor/TensorProxy has been raised for discussion in separate PRs & discussions.

We are debating whether we want to lift shape inference logic into the prologue trace in Thunder in general.

i.e., for a program

   def foo(...):
      # assume we computed a tensor `t0`
      t1 = t0.reshape(t0.size(0), -1)

If we leave all shape logic in the computation trace, it would be simplified to something like this:

    # produce t0 from some earlier trace
    (i0, i1) = prims.shape(t0)

    i2 = prims.mul(1, i0)
    i3 = prims.mul(i2, i1)  # this is t0.numel

    i4 = prims.mul(1, i0)
    i5 = prims.div(i3, i4)  # this is the simplified logic in clang.reshape with `-1` in the entry
    t1 = prims.reshape(t0, (i0, i5))
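
For reference, here is a minimal runnable sketch of that `-1` resolution (illustrative only; `clang.reshape`'s actual implementation differs):

    from math import prod

    def concretize_reshape(input_shape, target_shape):
        # A -1 entry is resolved as the input's numel divided by the
        # product of the known (non -1) target dimensions.
        numel = prod(input_shape)
        known = prod(d for d in target_shape if d != -1)
        return tuple(numel // known if d == -1 else d for d in target_shape)

    # e.g. reshaping a (2, 3, 4) tensor with (2, -1) yields (2, 12)
    assert concretize_reshape((2, 3, 4), (2, -1)) == (2, 12)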

One alternative is to lift all shape logic into the prologue trace, so we'll have:

@prologue trace:
foo(...):
    # all the shape logic to compute i0 & i1 from original input

    i2 = prims.mul(1, i0)
    i3 = prims.mul(i2, i1)  # this is t0.numel

    i4 = prims.mul(1, i0)
    i5 = prims.div(i3, i4)  # this is the simplified logic in clang.reshape with `-1` in the entry  
    return (..., i0, i5, ...)  # NOTE: we are not necessarily passing i0 / i5, it could be any equivalent symbols.

@compute trace:
foo(..., i0, i5, ...):
    # compute t0
    t1 = prims.reshape(t0, (i0, i5))  # here we cannot see that i0 is equivalent to t0.shape[0]

I think the first version, where we can see how the shape is computed within the operation, would simplify the generated kernel, since we do not need to resolve/validate the reshape concretization.

  • Follow up with codegen example on the impact of shape operation vs opaque scalar reshape.
@jjsjann123 (Collaborator, Author)

Logging offline suggestions by @mruberry:

  • It's OK to have prims.shape appear in the trace, which allows overlapping CPU and GPU computation; if it's in the prologue, that won't happen.

I think this justifies #1113 . cc'ing @t-vi

  • It'd be nice if each shape query appeared only once per tensor, at the global scope, and in a place that didn't disrupt fusions or require fusion executors like nvFuser to reason about the presence of the shape calls.

I think this just requires a decent DCE to aggregate the shape queries. A shape query in Python is not trivial and takes on the order of microseconds.
Meanwhile, I'm uncertain whether we would want to hide shape calls from fusion executors. We need to evaluate the impact on generated kernels as well as the caching overhead.
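
As a rough illustration of the kind of aggregation meant here (over a toy trace representation, not Thunder's actual trace classes):

    # Toy trace: a list of (outputs, op, inputs) tuples. Keep the first
    # prims.shape query per tensor; rewrite later uses to its proxies.
    def dedupe_shape_queries(trace):
        first_query = {}   # tensor name -> outputs of its first shape query
        rename = {}        # duplicate proxy name -> canonical proxy name
        result = []
        for outputs, op, inputs in trace:
            inputs = tuple(rename.get(i, i) for i in inputs)
            if op == "prims.shape":
                tensor = inputs[0]
                if tensor in first_query:
                    for dup, canon in zip(outputs, first_query[tensor]):
                        rename[dup] = canon
                    continue   # drop the duplicated query
                first_query[tensor] = outputs
            result.append((outputs, op, inputs))
        return result

    trace = [(("i0", "i1"), "prims.shape", ("t0",)),
             (("i2", "i3"), "prims.shape", ("t0",)),   # duplicate
             (("t1",), "prims.reshape", ("t0", "i2", "i3"))]
    print(dedupe_shape_queries(trace))   # t1 now consumes (i0, i1)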

  • It should rely on direct comparisons of tuples of numbers and symbolic values to determine if two shapes are the same, and not try to infer they are the same because they have a common provenance.

I agree that this should be the principle for how constraints are inserted. But if provenance can be used to simplify such constraints (e.g. equivalence, non-negativity), I think it should still be leveraged.
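
For illustration, the comparison-based principle could look like this (Sym is a toy stand-in for an IntegerProxy; this is not Thunder's API):

    class Sym:
        """Toy stand-in for a symbolic dimension (e.g. an IntegerProxy)."""
        def __init__(self, name):
            self.name = name

    def same_shape(a, b):
        # Two shapes match only if they compare equal element-wise:
        # equal concrete numbers, or the very same symbol object.
        # No reasoning about common provenance is attempted.
        if len(a) != len(b):
            return False
        return all(x is y if isinstance(x, Sym) or isinstance(y, Sym) else x == y
                   for x, y in zip(a, b))

    i0, i1 = Sym("i0"), Sym("i1")
    assert same_shape((i0, 4), (i0, 4))     # same symbol, same number
    assert not same_shape((i0,), (i1,))     # distinct symbols: not provably equal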

jjsjann123 self-assigned this Sep 10, 2024
@t-vi (Collaborator) commented Sep 10, 2024

So currently, the prologue performs exactly two things:

  • collecting tensors (and possibly soon other inputs) for the computation trace
  • checking things

I wonder if it would be good to keep things this way, in particular not having the prologue compute things for the computation trace. I'm not saying whether it should be this way or not, just that if we change it, it should be a very deliberate choice.

@mruberry (Collaborator)

> So currently, the prologue performs exactly two things:
>
>   • collecting tensors (and possibly soon other inputs) for the computation trace
>   • checking things
>
> I wonder if it would be good to keep things this way, in particular not having the prologue compute things for the computation trace. I'm not saying whether it should be this way or not, just that if we change it, it should be a very deliberate choice.

Some computations are so closely related to "checks" that it would reduce the total CPU work to compute them in the prologue.

In general I still like thinking about computing everything possible in the prologue. We have a lot of opportunities to optimize the performance of prologues. @jjsjann123 is completely correct that computing everything in the prologue does not take advantage of overlapping CPU and GPU computation, and it's interesting to see if we can do that effectively.

@t-vi (Collaborator) commented Sep 10, 2024

> Some computations are so closely related to "checks" that it would reduce the total CPU work to compute them in the prologue.
>
> In general I still like thinking about computing everything possible in the prologue. We have a lot of opportunities to optimize the performance of prologues. @jjsjann123 is completely correct that computing everything in the prologue does not take advantage of overlapping CPU and GPU computation, and it's interesting to see if we can do that effectively.

I absolutely agree.

The key property we need to keep here is the "thinness" of Thunder. By this I mean that if a user of Thunder knows they're calling the same computation 50x with controlled inputs (e.g. in the training loop), they can run the prologue once and then rely on the compute function working for them.

Note that if we anticipate power users of Thunder doing this, we also save them computation by putting it in the prologue, which might be even more attractive than overlapping with the GPU.
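
In code, the usage pattern being protected might look like this (a hypothetical sketch with stub functions, not the current public API):

    def prologue(*args):
        # hypothetical: perform all input checks and collection once
        assert all(x is not None for x in args)
        return args, ()

    def compute(*collected):
        # hypothetical: the unguarded computation
        return sum(collected)

    collected, _ = prologue(1.0, 2.0)   # validate and collect inputs once
    for step in range(50):              # then reuse the compute fn, e.g.
        out = compute(*collected)       # per training-loop iteration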

@jjsjann123 (Collaborator, Author) commented Dec 6, 2024

I want to highlight the issue we are seeing here as well.

@jjsjann123 (Collaborator, Author)

I think our current shape modeling is a bit too explicit: every shape query leaves its own prim in the trace, and we don't have a cleanup path to handle that.

Take broadcast as an example:

import thunder     
import torch
dtype = torch.float32
 
def foo(a, b):
    return a + b
 
jfoo = thunder.jit(foo, cache="symbolic values")
#jfoo = thunder.jit(foo)
 
a = torch.randn(2, 2, device="cuda")
b = torch.randn(1, device="cuda")
 
out = jfoo(a, b)

In the trace right after the interpreter, there are lots of duplicated prims.shape(...) calls on each tensor, which obviously need to be cleaned up.

@torch.no_grad()
@no_autocast
# No signature available
  # /volume/thunder_dynamic/broadcast.py:6:         return a + b
  t15 = ltorch.add(t_0, t_1, alpha=1)  # t15: "cuda:0 f32[[IntegerProxy name=i0, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=2, static=CONSTRAINT.CONSTRAINABLE]]"
    # (i0, i1) = prims.shape(t_0)
    # (i0, i1) = prims.shape(t_0)
    # (i2,) = prims.shape(t_1)
    # (i2,) = prims.shape(t_1)
    # b3 = prims.eq(i1, 1)  # b3: "bool False"
    # b4 = prims.eq(i1, i1)  # b4: "bool True"
    # b5 = prims.eq(i0, 1)  # b5: "bool False"
    # b6 = prims.eq(i0, i0)  # b6: "bool True"
    # b7 = prims.eq(i1, 1)  # b7: "bool False"
    # b8 = prims.eq(i2, 1)  # b8: "bool True"
    # (i0, i1) = prims.shape(t_0)
    # (i0, i1) = prims.shape(t_0)
    # (i2,) = prims.shape(t_1)
    # (i2,) = prims.shape(t_1)
    # b9 = prims.eq(i2, i0)  # b9: "bool False"
    # (i2,) = prims.shape(t_1)
    # (i2,) = prims.shape(t_1)
    # (i2,) = prims.shape(t_1)
    # b10 = prims.eq(i1, i2)  # b10: "bool False"
    # b11 = prims.eq(i2, 1)  # b11: "bool True"
    # b12 = prims.ne(i1, -1)  # b12: "bool True"
    # (i2,) = prims.shape(t_1)
    # t14 = prims.broadcast_in_dim(t_1, (i0, i1), (1,))  # t14: "cuda:0 f32[[IntegerProxy name=i0, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=2, static=CONSTRAINT.CONSTRAINABLE]]"
    # t15 = prims.add(t_0, t14)  # t15: "cuda:0 f32[[IntegerProxy name=i0, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=2, static=CONSTRAINT.CONSTRAINABLE]]"
  return t15

There is also duplication happening in the prologue trace:

@torch.no_grad()
@no_autocast
def prologue(*args, **kwargs):
  # args: "Any"
  prims.check_len(args, 2)
  # kwargs: "Any"
  prims.check_len(kwargs, 0)
  a: "cuda:0 f32[[IntegerProxy name=i0, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=2, static=CONSTRAINT.CONSTRAINABLE]]" = args[0]
  b: "cuda:0 f32[[IntegerProxy name=i2, value=1, static=CONSTRAINT.CONSTRAINABLE]]" = args[1]
  (i0, i1) = prims.shape(a)
  subscr: "Any" = args[0]
  obj: "Any" = subscr.shape
  i0: "int 2" = obj[0]
  i1: "int 2" = obj[1]
  (i0, i1) = prims.shape(a)
  prims.check_tensor_shape_and_metadata(a, (i0, i1), 'cuda:0', torch.float32, False)
  (i2,) = prims.shape(b)
  p0: "Any" = args[1]
  p1: "Any" = p0.shape
  i2: "int 1" = p1[0]
  (i2,) = prims.shape(b)
  prims.check_tensor_shape_and_metadata(b, (i2,), 'cuda:0', torch.float32, False)
  # ...
  return ((a, b), ())

We have lifted the prims.shape logic up here:

  a = args[0]
  (i0, i1) = prims.shape(a)

There are also these pieces doing the same thing in the prologue trace:

  subscr: "Any" = args[0]
  obj: "Any" = subscr.shape
  i0: "int 2" = obj[0]
  i1: "int 2" = obj[1]

which might be slightly harder to detect and clean up.
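
One could imagine a peephole pass that folds this subscript pattern back into an explicit shape query, e.g. (toy trace form again; the op names here are made up for illustration, not real Thunder prims):

    # Fold `obj = x.shape; i_k = obj[k]` sequences into explicit
    # per-dimension shape queries so they can be deduplicated like
    # prims.shape calls. "getattr_shape"/"shape_dim" are hypothetical ops.
    def fold_shape_subscripts(trace):
        shape_holder = {}   # proxy bound to `x.shape` -> tensor name x
        result = []
        for outputs, op, inputs in trace:
            if op == "getattr_shape":                  # obj = x.shape
                shape_holder[outputs[0]] = inputs[0]
                continue
            if op == "getitem" and inputs[0] in shape_holder:
                tensor, k = shape_holder[inputs[0]], inputs[1]
                result.append((outputs, "shape_dim", (tensor, k)))
                continue
            result.append((outputs, op, inputs))
        return result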

@mruberry (Collaborator) commented Dec 6, 2024

> I think our current shape modeling is a bit too explicit: every shape query leaves its own prim in the trace, and we don't have a cleanup path to handle that. [...]

Let me know if you want to chat about some ideas! Splatting the shape information once could be a fix. Another option could be to update the CSE pass. I'm a little surprised it doesn't work already.

@jjsjann123 (Collaborator, Author)

> Another option could be to update the CSE pass. I'm a little surprised it doesn't work already.

CSE only works at the top level of the trace, so the repetitive patterns inside the lower hierarchy (the subsymbols) don't get cleaned up, which is an eyesore.
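
A rough sketch of what descending into the hierarchy could look like (toy nested-trace form, not Thunder's actual BoundSymbol API):

    # Toy nested trace: each node is (outputs, op, inputs, subsymbols).
    # Because duplicated shape queries produce the *same* proxies (e.g.
    # `(i0, i1) = prims.shape(t_0)` repeated), later duplicates can
    # simply be dropped, at every nesting level.
    def cse(nodes, seen=None):
        seen = set() if seen is None else seen
        result = []
        for outputs, op, inputs, subs in nodes:
            key = (op, inputs, outputs)
            if key in seen:
                continue            # duplicate expression, drop it
            seen.add(key)
            result.append((outputs, op, inputs, cse(subs, seen)))
        return result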

@t-vi (Collaborator) commented Dec 6, 2024

Note that #1500 added remove_duplicate_number_proxies to DCE because we prefer to have a single producer for proxies.
