TensorProxy.shape should be unpacked automatically #1253

Closed
jjsjann123 opened this issue Oct 3, 2024 · 3 comments

@jjsjann123
Collaborator

🐛 Bug

TensorProxy.shape remains a plain attribute, so accessing it does not leave an unpack in the trace. This causes issues when TensorProxy.shape contains NumberProxy entries.

In #1201, commit 26f883e, I have to rely on this hack; otherwise, the grad transform would see an invalid trace,

e.g. in a trivial slice:

import torch
import thunder

def foo(a):
  return a[..., : a.shape[-1]]

jfoo = thunder.jit(foo, cache="symbolic values")
jfoo(torch.randn(2, 8))  # hypothetical example input; the trace below assumes a 2-D CPU float tensor whose last extent is 8

def computation(a, i1):
  # a: "cpu f32 [IntegerProxy name=i0], [IntegerProxy name=i1]
  # i1: "int 8"
  a = ltorch.getitem(a, (..., slice(None, i1, None)))
      # (i0, i1) = prims.shape(a)  # THIS IS WHAT THE HACK DOES
      # b2 = prims.lt(i0, 0)
      # ...

Without the explicit unpack of a.shape, the subsymbols in ltorch.getitem access i0, which is carried implicitly by a.shape but never appears explicitly in the trace.
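
For illustration, the automatic unpacking proposed here is conceptually a property that records the query instead of silently handing out the proxies. A minimal toy sketch, using made-up FakeTensorProxy/FakeNumberProxy stand-ins rather than Thunder's actual classes:

trace = []  # stand-in for the computation trace: a list of recorded symbols

class FakeNumberProxy:
  def __init__(self, name):
    self.name = name

class FakeTensorProxy:
  def __init__(self, name, shape_proxies):
    self.name = name
    self._shape = tuple(shape_proxies)

  @property
  def shape(self):
    # Recording the access gives the extents a visible producer in the trace,
    # so subsymbols that later consume i0/i1 no longer read "invisible" values.
    extents = ", ".join(p.name for p in self._shape)
    trace.append(f"({extents}) = prims.shape({self.name})")
    return self._shape

a = FakeTensorProxy("a", [FakeNumberProxy("i0"), FakeNumberProxy("i1")])
last = a.shape[-1]  # touching .shape now leaves an unpack behind
print(trace)        # ['(i0, i1) = prims.shape(a)']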

Alternative

This problem could also be properly resolved in the prologue trace, i.e. here i1 is unpacked in the prologue because it is consumed by the top-level symbol ltorch.getitem. Unfortunately, proxies used only by subsymbols are not considered consumed by the computation trace today (see code), so i0 isn't getting unpacked in the prologue yet.

So for input TensorProxys, I think prologue unpacking is the right choice here. For intermediate tensors, it might require a mixed solution, which goes back to the conversation we had in #1133.
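
For reference, a toy model of the consumption question above, using a hypothetical Node container rather than the actual trace code linked above: today only the top-level symbol's arguments count as consumed, so i0, which is read only inside the getitem subsymbols, is never requested from the prologue.

from dataclasses import dataclass, field

@dataclass
class Node:
  args: tuple                    # proxy names this symbol reads directly
  subsymbols: list = field(default_factory=list)

def consumed(node, include_subsymbols=False):
  # Names the computation is considered to consume, and hence would
  # request from the prologue.
  names = set(node.args)
  if include_subsymbols:
    for sub in node.subsymbols:
      names |= consumed(sub, include_subsymbols=True)
  return names

getitem = Node(args=("a", "i1"),
               subsymbols=[Node(args=("a", "i0"))])
print(consumed(getitem))                           # {'a', 'i1'} -> today
print(consumed(getitem, include_subsymbols=True))  # adds 'i0'   -> the alternative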

jjsjann123 self-assigned this on Oct 3, 2024
@t-vi
Collaborator

t-vi commented Oct 3, 2024

This problem could also be properly resolved in the prologue trace, i.e. here i1 is unpacked in the prologue because it is consumed by the top-level symbol ltorch.getitem. Unfortunately, proxies used only by subsymbols are not considered consumed by the computation trace today (see code), so i0 isn't getting unpacked in the prologue yet.

I'm a bit skeptical about the alternative here, and my gut feeling is that the main solution (unpacking the shape "close" to where it is used) is preferable.
To my mind, the alternative implies a major change: the subsymbols, considered as a block of code, would have inputs that the symbol itself does not. This is tricky on several levels (producers / consumers etc.).

The tricky thing with re-unpacking could be that I'm not sure we handle one name being assigned multiple times well, so we may need new names every time we do this.
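
A toy illustration of that naming concern, using a made-up versioned-name helper rather than Thunder's actual proxy naming: re-unpacking mints i3, i3_1, i3_2, ... instead of assigning i3 twice, so the trace stays SSA.

import itertools
from collections import defaultdict

_versions = defaultdict(itertools.count)

def fresh(name):
  # The first request returns the original name; later ones get a version suffix.
  n = next(_versions[name])
  return name if n == 0 else f"{name}_{n}"

print(fresh("i3"), fresh("i3"), fresh("i3"))  # i3 i3_1 i3_2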

@jjsjann123
Collaborator Author

Sorry, I missed this email earlier (Gmail access is limited at the moment, depending on how reliable the VPN is).
I'm working on the automatic unpacking in #1260.

The tricky thing with re-unpacking could be that I'm not sure we handle one name being assigned multiple times well, so we may need new names every time we do this.

I'm already hitting this one. In #1260, every shape query results in a prims.shape in the trace, and since the primitive returns the NumberProxy carried by the tensor, we end up assigning to a given name multiple times in the trace, breaking SSA.

e.g. with the following program:

def foo(a):
  return torch.reshape(a, [a.numel()]).relu()

We have a trace:

def computation(a, i0, i1):
  i2 = operator.mul(1, i0)
  i3 = operator.mul(i2, i1)
  # ...
  t11 = torch.reshape(a, [i3])
  # ...
  t20 = torch.nn.functional.relu(t11, False)
  # ...
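
For intuition: under symbolic shapes, a.numel() lowers to a running product over the shape proxies, which is exactly what the operator.mul lines above record. A tiny sketch of the same left fold with concrete ints:

from functools import reduce
import operator

def symbolic_numel(extents):
  # ((1 * i0) * i1): mirrors the i2/i3 lines in the trace above
  return reduce(operator.mul, extents, 1)

print(symbolic_numel([3, 4]))  # 12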

I'm seeing a couple of issues here:

  1. We can have multiple identical unpackings in the subsymbols, e.g. in the decomposition of relu, prims.shape(t11) is recorded multiple times. This can also happen when multiple TensorProxys with the same shape are queried. I think this can be cleaned up with a CSE/DCE pass through the subsymbols so that we keep just a single query (see the sketch after this list).
  # t39 = ltorch.gt(t11, 0)
    # (i3,) = prims.shape(t11)
    # (i3,) = prims.shape(t11)
    # ...
    # t39 = prims.gt(t11, 0.0)
  2. We have the unpacking # (i0, i1) = prims.shape(a) in the subsymbols, even though a's shape is already unpacked in the prologue. This can also happen for shape queries at different levels. These won't be an issue until we flatten the symbol, so I think we could leave this for the code that calls flattening to handle (a fusion pass, for example?).
  t11 = torch.reshape(a, [i3])
    # t11 = ltorch.reshape(a, [i3])
      # (i0, i1) = prims.shape(a)
      # b14 = prims.eq(i0, i3)
      # b15 = prims.ge(i3, 0)
      # i18 = prims.mul(1, i3)
      # t11 = prims.reshape(a, (i3,))
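
A rough sketch of the CSE idea from point 1, using a hypothetical Sym container rather than Thunder's actual trace types: walk the subsymbols recursively and keep only the first occurrence of each identical shape query.

from dataclasses import dataclass, field

@dataclass
class Sym:
  op: str                       # e.g. "prims.shape"
  args: tuple = ()              # e.g. ("t11",)
  subsymbols: list = field(default_factory=list)

def dedupe_shape_queries(sym, seen=None):
  # Drop repeated prims.shape subsymbols that query the same tensor,
  # recursing so nested decompositions are cleaned up as well.
  seen = set() if seen is None else seen
  kept = []
  for sub in sym.subsymbols:
    key = (sub.op, sub.args)
    if sub.op == "prims.shape":
      if key in seen:
        continue                # identical query already kept: CSE it away
      seen.add(key)
    dedupe_shape_queries(sub, seen)
    kept.append(sub)
  sym.subsymbols = kept
  return sym

# Mirrors the ltorch.gt excerpt in point 1: two identical queries collapse to one.
gt = Sym("ltorch.gt", ("t11", 0), [
  Sym("prims.shape", ("t11",)),
  Sym("prims.shape", ("t11",)),
  Sym("prims.gt", ("t11", 0.0)),
])
dedupe_shape_queries(gt)
print([s.op for s in gt.subsymbols])  # ['prims.shape', 'prims.gt']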

@jjsjann123
Collaborator Author

jjsjann123 commented Dec 6, 2024

I'm closing this one for now, since #1260 has been merged.

For the duplicated shape query, I'll leave that for #1133 to track.
