
Handling inplace through SSA #145

Open · t-vi opened this issue on Apr 9, 2024 · 16 comments
Labels: enhancement (New feature or request), help wanted (Extra attention is needed)

Comments

@t-vi (Collaborator) commented on Apr 9, 2024

This issue is to facilitate discussion of inplace handling, namely the "big" solution of having a static single assignment (SSA) representation.

For any handling of inplace operations, we want to make certain that two things are achieved:

  • we don't want to take shortcuts that complicate passes by forcing them to detect obstacles to optimization, because that would harm the usability and extensibility of Thunder,
  • we don't want to create ad-hoc band-aids to get things working that we would later have to walk back in order to introduce proper handling, because developing in the open more or less means no regressions.

Some thoughts from video/chat discussions:

About the problem:

  • The key difficulty in SSA is that we would need to keep track of which tensors get modified by an inplace update (i.e. which tensors' memory is to be updated), so we would need to know about views (the fancy term is alias analysis),
  • this is difficult for some operations in PyTorch (e.g. reshape),
  • "assuming the worst" works to some extent.

Solution considerations:

  • Likely we would want inplace updates to have all affected tensors as outputs,
  • on inputs we would need to check for aliases as part of the prologue (maybe later with a separate "assume aliasing is OK" cache mode of sorts); a minimal sketch of such a check follows below,
  • operations need to know whether their output is a view of their inputs (difficult for reshape, easy for most others),
  • initially, we would only check whether tensors share storage,
  • likely the translation could be done in the interpretation phase,
  • we would need versioning / disambiguation of versions for tensor proxies during this, but not once we have the SSA form.

Later versions could refine the alias analysis as needed.
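
To make the prologue idea above concrete, here is a minimal sketch of the coarse "do any inputs share storage?" check, using only public PyTorch calls (untyped_storage() and data_ptr()); the function name check_no_aliasing is made up for illustration:

import torch

def check_no_aliasing(*tensors: torch.Tensor) -> None:
    # Reject inputs that share a storage; this is the coarse initial check described above.
    seen = {}
    for i, t in enumerate(tensors):
        ptr = t.untyped_storage().data_ptr()
        if ptr in seen:
            raise RuntimeError(f"input {i} shares storage with input {seen[ptr]}")
        seen[ptr] = i

a = torch.randn(4)
b = a.view(2, 2)                       # b aliases a
check_no_aliasing(a, torch.randn(4))   # passes
# check_no_aliasing(a, b)              # would raise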

@tfogal @mruberry @IvanYashchuk

@t-vi added the enhancement (New feature or request) and help wanted (Extra attention is needed) labels on Apr 9, 2024
@mruberry (Collaborator) commented on Apr 9, 2024

One thing I've struggled with in these discussions is how static single assignment form addresses the challenges of inplace operations. Doesn't Thunder's IR essentially already have the SSA property? (I think there are some cases where we model operations that do nothing as the equivalent of x = x, but I don't think they're a problem.)

@t-vi (Collaborator, Author) commented on Apr 9, 2024

Well, we do have SSA form in the Thunder IR until we allow inplace operations, which seems to be one of the things people want to do. The transformation to SSA is just there to make sure we don't run into correctness issues.

@mruberry (Collaborator) commented on Apr 9, 2024

Well, we do have SSA form in the Thunder IR until we allow inplace operations, which seems to be one of the things people want to do. The transformation to SSA is just there to make sure we don't run into correctness issues.

OK; I guess I think about this as "let's handle inplace operations while preserving our SSA form and relying on dataflow to determine the order of operations"?

@t-vi (Collaborator, Author) commented on Apr 9, 2024

Exactly, or even "let's preserve our SSA form and the fact that dataflow describes the order of operations (and admissible reorderings) even when we want inplace".

@mruberry (Collaborator) commented on Apr 9, 2024

Maybe we can afford to be overly conservative with operations like reshape and, for ambiguous cases, act as if they (might have) created a view?

What gets a little tricky is that different executors may have different rules for when they do/don't create views. It's interesting to think about how different executors might communicate this.

Alternatively, maybe we should define semantics like "if the reshaped tensor shares storage with another tensor, then the result must be a tensor that doesn't share its storage", and we can update the torch executor to respect those semantics?
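
For reference, eager PyTorch's reshape already decides view-vs-copy based on the input's layout: it returns a view when the requested shape is compatible with the current strides and a copy otherwise. A small demonstration (comparing storage data pointers is just one way to observe the sharing):

import torch

a = torch.arange(6)
v = a.reshape(2, 3)       # contiguous input: reshape returns a view
print(v.untyped_storage().data_ptr() == a.untyped_storage().data_ptr())  # True

t = a.reshape(2, 3).t()   # transposed, non-contiguous
c = t.reshape(6)          # incompatible with t's strides: reshape must copy
print(c.untyped_storage().data_ptr() == t.untyped_storage().data_ptr())  # False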

@apaz-cli (Contributor) commented on Apr 9, 2024

How does static single assignment form address the challenges of inplace operations?

For the most part, it does so by assuming that these sorts of side effects don't exist. When you convert to SSA, you're meant to rename a variable any time it's potentially modified, and hide any conditional logic, jumps, etc. behind phi nodes. That's sort of the central idea of SSA. We have no control flow, so we don't need phi nodes. All we need is a better way than the variable name to track the identity of a tensor.

So you would rename:

a = torch.zeros(5)
b = torch.ones(5)
a.add_(b)
return a

To something like:

t0 = torch.zeros(5)
t1 = torch.ones(5)
t2 = t0.add(t1)
return t2

All that's required to do that is to iterate through the trace, operation by operation, renaming every tensor that's modified. At the end though, you run into a problem. Suppose a was an argument. You need a way to assign back into it.

def foo(a: torch.Tensor):
  # Do other stuff
  b = torch.ones(5)
  a.add_(b)
  return a

becomes

def foo(t0: torch.Tensor):
  # Do other stuff
  t1 = torch.ones(5)
  t2 = t0.add(t1)
  tensor_memcpy(t0, t2) # to, from
  return t0

I'm not sure what we should do in this situation. nvFuser has it figured out: it just writes the answer back in place. But we would have to do it with torch ops. If there's a way to perform this memcpy, ideally in a way that can easily be optimized out, let me know (a sketch of one torch-level candidate follows below). In that case, I think writing a pass that functionalizes this stuff is pretty easy. If there isn't, I'm not sure how to do an SSA-style functionalization pass here.
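
For what it's worth, one torch-level candidate for that write-back is torch.Tensor.copy_, which copies the source into self in place; whether a later pass can fuse or elide it is exactly the open question. A sketch of the example above using it (just one possible spelling of the tensor_memcpy placeholder, not a settled design):

import torch

def foo(t0: torch.Tensor):
    # Do other stuff
    t1 = torch.ones(5)
    t2 = t0.add(t1)    # functionalized form of add_
    t0.copy_(t2)       # write the result back into the argument's storage
    return t0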

To support this, we would only have to add either:
A) TensorProxies have an identity (which we can disregard after functionalization when the user is writing their passes)
B) Symbols contain a list of all the tensor references that they write to

Either would work.

@mruberry (Collaborator) commented on Apr 9, 2024

That's the thing about functionalizing: it alone is not enough. If that's the approach that's taken, then there must be a later pass, with information not present in the Python program, that can update the memory as needed. The situation is also more complicated than a straightforward example like the one above suggests. If a tensor is written to inplace, then how many tensors are modified? Maybe one, maybe two, maybe ten. How are these relationships expressed with a functionalization pass? I guess the operation would have to be repeated n times? And it's unclear whether such programs could be optimized as well as programs that express their inplace logic directly.

@IvanYashchuk (Collaborator) commented

As a reminder, PyTorch returns new Tensors for inplace operations; the return isn't just None. So blocks of code like

a.add_(b)
return a

could equivalently be rewritten as

c = a.add_(b)
return c

and with this rewrite, maybe we don't need any special trace pass to reconstruct the dataflow?

@t-vi (Collaborator, Author) commented on Apr 10, 2024

I am not sure that "there is a PyTorch programming style that makes it easy for us" helps us that much, because silently producing incorrect results for valid PyTorch code unless users stick to some subset does not seem like a good option.

If you don't have better ideas, I would like to have

  a = torch.randn(4)
  b = a.view(2, 2)[1]
  b += 1

to translate to something along the lines of

   a = torch.randn(4)
   b = a.view(2, 2)[1]
   c = b + 1
   b(new) = copy_(b, c)
   a(new) = _update_storage_(a, b(new))

with b(new) and a(new) being new proxies. An optimizer can then take the copy_ and fuse it into the computation.
The main nuisances I see are

  • before reaching the SSA form, proxies will need to be versioned in a way that disambiguates the (new) part here (a small sketch of such versioning follows below),
  • the execution either needs to make sure that b is still a view into a or know which bits of b to copy to which bits of a,
  • strictly speaking, we would need to preserve the "is a view into" property, but with reshape it is an implementation detail (I'd be inclined to treat reshape as always creating a new tensor for this aspect).
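
A small, purely illustrative sketch of the versioning bookkeeping mentioned in the first bullet; ProxyVersions is a hypothetical helper, not an existing Thunder class:

from collections import defaultdict

class ProxyVersions:
    # Hypothetical bookkeeping: map a base proxy name to its current version so
    # that b and b(new) can be told apart while the SSA form is being built.
    def __init__(self):
        self._versions = defaultdict(int)

    def current(self, name: str) -> str:
        return f"{name}_v{self._versions[name]}"

    def bump(self, name: str) -> str:
        # Call whenever an inplace op (or copy_) produces a new value for name.
        self._versions[name] += 1
        return self.current(name)

versions = ProxyVersions()
versions.current("b")  # 'b_v0'
versions.bump("b")     # 'b_v1', i.e. the b(new) above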

@apaz-cli (Contributor) commented on Apr 10, 2024

@IvanYashchuk yes, but a is c evaluates to True. The inplace op just returned a, and both Python variables point to the same object.

The problem that I think @mruberry is referring to is the problem of views. In general, it's really difficult to know how many tensors are actually being modified by an inplace operation. Consider the following:

a = torch.zeros(5)
b = a[0:1]
c = a[1:3]
b[0] = 1  # writes a[0]
c[1] = 1  # writes a[2]
print(a is b)  # False
print(a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr())  # True
print(a)  # [1, 0, 1, 0, 0]

How do we know, when we write to the backing storage of b, that we have to rename a? And can we find a way not to rename b when we write to c? It's a famous open problem in PyTorch.

I'm... not really saying that I have a general solution to the "how many tensors am I actually changing when I do an inplace operation on a view" problem. But I do think that we can warn users when they do inplace operations on tensors that we know are views (a sketch follows below). And I also think that the other cases are pretty easy. If they still aren't, let me know that I'm wrong.

I see that @t-vi just posted, and I think we independently came to largely the same conclusion. It should be easy enough to tag which tensors are views. So, we'd have to add two bits of info to meta functions. Or to tensors, or some combination. Either works.
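
A sketch of the "warn on inplace writes to known views" idea; note that Tensor._is_view() and Tensor._base are private PyTorch attributes, so this illustrates the intent rather than a supported recipe:

import warnings
import torch

def warn_if_inplace_on_view(t: torch.Tensor, op_name: str) -> None:
    # _is_view() is True for tensors created by view-producing ops (view, slicing, ...).
    if t._is_view():
        warnings.warn(
            f"{op_name} writes in place to a view; the base tensor "
            f"(shape {tuple(t._base.shape)}) will also be modified"
        )

a = torch.zeros(5)
b = a[0:1]
warn_if_inplace_on_view(b, "add_")  # warns: b is a view of a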

@mruberry (Collaborator) commented on Apr 10, 2024

... I would like to have

  a = torch.randn(4)
  b = a.view(2, 2)[1]
  b += 1

to translate to something along the lines of

   a = torch.randn(4)
   b = a.view(2, 2)[1]
   c = b + 1
   b(new) = copy_(b, c)
   a(new) = _update_storage_(a, b(new))

with b(new) and a(new) being new proxies. An optimizer can then take the copy_ and fuse it into the computation. The main nuisances I see are

  • before reaching the SSA form, proxies will need to be versioned in a way that disambiguates the (new) part here,
  • the execution either needs to make sure that b is still a view into a or know which bits of b to copy to which bits of a,
  • strictly speaking, we would need to preserve the "is a view into" property, but with reshape it is an implementation detail (I'd be inclined to treat reshape as always creating a new tensor for this aspect).

The thing I don't like about this approach is that there's nothing in the program that says a cannot be used after the creation of a(new), and if we were to introduce such a concept then we'd have to add significant complexity to our existing dataflow-based passes.

A dataflow-based approach to ensure the correctness of these operations might look something like

a, a_storage0 = torch.randn(4)
b, _ = a.view(2, 2, storage=a_storage0)[1]  # maybe b, a_storage0 = would be clearer?
c, a_storage1 = torch.add_(b, 1, storage=a_storage0)

and then if someone performed an operation like a + 1 later it would look like

d, d_storage0 = torch.add(a, 1, storage=a_storage1)

Now obviously that's pretty ugly, and I think there's room for refinement, but if we have some mechanism for accepting and updating storage generations, like in the above, then you can order the operations through dataflow alone. In particular, the a + 1 operation could not be reordered before the inplace add.

I think what's tricky about this approach is thinking about how to best associate views and storage, but the tuple (view, storage) seems like a really good perspective to have.

Edit: we could probably make every current TensorProxy a tuple of view and storage so that calls like d = torch.add(a, 1) were really something like (d_view, d_storage0) = torch.add((a_view, a_storage1), 1) but we wouldn't have to make the split so prominent when printing the program (the tuples could have names)
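
A rough sketch of what that (view, storage) split could look like as plain data structures; the class names and fields here are invented for illustration and are not Thunder's actual proxies:

from dataclasses import dataclass

@dataclass(frozen=True)
class StorageProxy:
    # One generation of a storage; every inplace write yields the next generation.
    name: str
    generation: int = 0

    def next_generation(self) -> "StorageProxy":
        return StorageProxy(self.name, self.generation + 1)

@dataclass(frozen=True)
class TensorProxyPair:
    # The (view, storage) tuple: view metadata plus the storage generation it
    # reads from and writes to.
    view_name: str
    storage: StorageProxy

a_storage0 = StorageProxy("a_storage")
a = TensorProxyPair("a", a_storage0)
b = TensorProxyPair("b", a_storage0)        # a view of a: same storage generation
a_storage1 = a_storage0.next_generation()   # produced by the inplace add through b
d = TensorProxyPair("d", a_storage1)        # a later a + 1 must depend on a_storage1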

@t-vi (Collaborator, Author) commented on Apr 25, 2024

Seems like #264 would also benefit from an SSA/functionalization pass, as it also deals with implicit state (except that it seems simpler in that we don't need to worry about aliasing).

@jjsjann123 (Collaborator) commented

Want to log one thing @mruberry mentioned in an offline discussion.

The scope of aliasing that Thunder is trying to support would also include aliases across program inputs.

I think this makes SSA trickier to handle, since we might not be able to reason about how to replay an inplace update on aliased inputs.

def foo(a, b):
  c = a.add_(1.0)
  e = b * 2
  return e

Assuming a and b are aliases, SSA would first need to replay the inplace update as something like the code below (and then de-SSA to write the update back into a's buffer):

  a0 = a.add(1.0)
  b0 = b.add(1.0)
  e = b0 * 2
  return e

But if a and b merely overlap, we wouldn't be able to replay a.add_(1.0) on b unless we know how to model the overlap.
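
A conservative sketch of how a prologue might at least tell "exact alias" apart from "mere overlap", using public tensor metadata (storage data pointer, storage_offset, sizes, strides); the extent computation is the usual conservative bound, both tensors are assumed to share a dtype, and the function names are made up:

import torch

def _extent(t: torch.Tensor):
    # Conservative element range [start, end) that t may touch within its storage.
    start = t.storage_offset()
    if t.numel() == 0:
        return start, start
    end = start + 1 + sum((size - 1) * stride for size, stride in zip(t.shape, t.stride()))
    return start, end

def alias_kind(a: torch.Tensor, b: torch.Tensor) -> str:
    if a.untyped_storage().data_ptr() != b.untyped_storage().data_ptr():
        return "no alias"
    if (a.storage_offset() == b.storage_offset()
            and a.shape == b.shape and a.stride() == b.stride()):
        return "exact alias"  # the easy case: replaying a.add_(1.0) on b is well defined
    a0, a1 = _extent(a)
    b0, b1 = _extent(b)
    return "overlap" if a0 < b1 and b0 < a1 else "disjoint views"

x = torch.zeros(8)
print(alias_kind(x[0:4], x.view(8)[0:4]))  # exact alias
print(alias_kind(x[0:4], x[2:6]))          # overlap
print(alias_kind(x[0:4], x[4:8]))          # disjoint views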

@mruberry (Collaborator) commented on May 7, 2024

Want to log one thing @mruberry mentioned in an offline discussion.

The scope of aliasing that Thunder is trying to support would also include aliases across program inputs.

I think this makes SSA trickier to handle, since we might not be able to reason about how to replay an inplace update on aliased inputs.

def foo(a, b):
  c = a.add_(1.0)
  e = b * 2
  return e

Assuming a and b are aliases, SSA would first need to replay the inplace update as something like the code below (and then de-SSA to write the update back into a's buffer):

  a0 = a.add(1.0)
  b0 = b.add(1.0)
  e = b0 * 2
  return e

But if a and b merely overlap, we wouldn't be able to replay a.add_(1.0) on b unless we know how to model the overlap.

I think @t-vi had an interesting idea to have an instruction like

a0, b0 = update(a, b)

which could provide information about the aliasing relationships in the trace, and that might help address this?

@jjsjann123 (Collaborator) commented

a0, b0 = update(a, b)

Yeah, that's what we want to do. The question is how exactly we update b0.
With the wildest kinds of memory overlap between a and b, it'll be pretty tricky to figure out how a.add_(1.0) would map onto b's storage for that update.

@mruberry (Collaborator) commented on May 8, 2024

a0, b0 = update(a, b)

Yeah, that's what we want to do. The question is how exactly we update b0. With the wildest kinds of memory overlap between a and b, it'll be pretty tricky to figure out how a.add_(1.0) would map onto b's storage for that update.

Agreed! I don't think this problem is, in general, solvable in time polynomial in the shapes and strides of the tensors. If the inplace operations are explicitly represented, I don't think we have that problem.
