Handling inplace through SSA #145
One thing I've struggled with in these discussions is how static single assignment form addresses the challenges of inplace operations. Doesn't Thunder's IR essentially already have the SSA property? (I think there are some cases where we model operations which do nothing as the equivalent of

---

Well, we do have SSA form in the Thunder IR until we allow inplace operations, which seems to be one of the things people want to do. The transformation to SSA is just there to make sure we don't run into correctness issues.

---

OK; I guess I think about this as "let's handle inplace operations while preserving our SSA form and relying on dataflow to determine the order of operations"?

---

Exactly, or even "let's preserve our SSA form and the fact that dataflow describes the order of operations (and admissible reorderings) even when we want inplace".

---

Maybe we can afford to be too conservative with operations like `reshape`. What gets a little tricky is that different executors may have different rules for when they do/don't create views. It's interesting to think about how different executors might communicate this. Alternatively, maybe we should define semantics like "if the reshaped tensor shares storage with another tensor, then the result must be a tensor that doesn't share its storage", and we can update the torch executor to respect those semantics?
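To illustrate the view-or-copy point in eager PyTorch: whether `reshape` returns a view depends on the input's layout (an illustrative snippet, not from the comment above):

```python
import torch

a = torch.arange(6).reshape(2, 3)

v = a.reshape(3, 2)  # contiguous input: reshape returns a view
print(v.data_ptr() == a.data_ptr())  # True, storage is shared

b = a.t().reshape(6)  # non-contiguous input: reshape must copy
print(b.data_ptr() == a.data_ptr())  # False, new storage
```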

---

For the most part, it does so by assuming that these sorts of side effects don't exist. When you convert to SSA, you're meant to rename a variable any time it's potentially modified, and hide any conditional logic, jumps, etc. behind phi nodes. That's the central idea of SSA. We have no control flow, so we don't need phi nodes. All we need is a better way than the variable name for tracking the identity of a tensor. So you would rename:

```python
a = torch.zeros(5)
b = torch.ones(5)
a.add_(b)
return a
```

to something like:

```python
t0 = torch.zeros(5)
t1 = torch.ones(5)
t2 = t0.add(t1)
return t2
```

All that's required to do that is to iterate through the trace, operation by operation, renaming every tensor that's modified. At the end, though, you run into a problem. Suppose

```python
def foo(a: torch.Tensor):
    # Do other stuff
    b = torch.ones(5)
    a.add_(b)
    return a
```

becomes

```python
def foo(t0: torch.Tensor):
    # Do other stuff
    t1 = torch.ones(5)
    t2 = t0.add(t1)
    tensor_memcpy(t0, t2)  # to, from
    return t0
```

I'm not sure what we should do in this situation. NVFuser has it figured out: just write the answer back inside. But we would have to do it with torch ops. If there's a way to perform this memcpy, ideally in a way that can be easily optimized out, let me know. In that case, I think writing a pass that functionalizes this stuff is pretty easy. If there isn't, I'm not sure how to do an SSA-style functionalization pass here. To support this, we would only have to add one of two things. Either would work.
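One candidate for that memcpy, for what it's worth: `torch.Tensor.copy_` writes one tensor's values into another's storage using a plain torch op. A minimal sketch of the functionalized trace using it (assuming the writeback can survive as a method call, which a later pass could pattern-match and elide):

```python
import torch

def foo(t0: torch.Tensor):
    t1 = torch.ones(5)
    t2 = t0.add(t1)  # functional version of the original a.add_(b)
    t0.copy_(t2)     # torch-level writeback into t0's storage
    return t0
```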

---

That's the thing about functionalizing: it alone is not enough. If that's the approach taken, then there must be a later pass, with information not present in the Python program, that can update the memory as needed. The situation is also more complicated than a straightforward example, like the one above, suggests. If a tensor is written to inplace, then how many tensors are modified? Maybe one, maybe two, maybe ten. How are these relationships expressed with a functionalization pass? I guess the operation would have to be repeated n times? And it's unclear whether such programs could be optimized as well as programs that can express their inplace logic directly.

---

As a reminder, PyTorch's inplace operations return the modified tensor; the return isn't just None. So blocks of code like

```python
a.add_(b)
return a
```

could equivalently be rewritten as

```python
c = a.add_(b)
return c
```

and with this rewrite, maybe we don't need any special trace pass to reconstruct the dataflow?
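One caveat worth noting: the tensor returned by an inplace op aliases its input, so capturing the return restores a dataflow edge but doesn't remove the aliasing. A quick check in eager PyTorch:

```python
import torch

a = torch.zeros(3)
b = torch.ones(3)
c = a.add_(b)  # inplace ops return self
print(c is a)  # True: c and a are the same tensor object
```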

---

I am not sure that "there is a PyTorch programming style that makes it easy for us" helps us that much, because silently producing incorrect results for valid PyTorch code unless users stick to some subset does not seem like a good option. If you don't have better ideas, I would like

```python
a = torch.randn(4)
b = a.view(2, 2)[1]
b += 1
```

to translate to something along the lines of

```python
a = torch.randn(4)
b = a.view(2, 2)[1]
c = b + 1
b(new) = copy_(b, c)
a(new) = _update_storage_(a, b(new))
```

with b(new) and a(new) being new proxies. An optimizer can then take the copy_ and _update_storage_ nodes into account, and elide them when possible.
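Writing such a pass seems mechanical enough to sketch. Below is a minimal, hypothetical version over a toy trace representation; `Op`, `functionalize`, and the whole encoding are illustrative stand-ins, not Thunder's actual IR, and propagating the update to aliases (the `_update_storage_` step above) is omitted:

```python
from dataclasses import dataclass
from itertools import count

# Hypothetical trace record; Thunder's real classes will differ.
@dataclass
class Op:
    name: str
    inputs: list
    outputs: list
    inplace_on: int = -1  # index of the input that is written; -1 if functional

def functionalize(ops):
    """Rewrite inplace ops as functional ops plus an explicit copy_ writeback."""
    fresh = (f"t{i}" for i in count())
    latest = {}  # original variable name -> current SSA name
    out = []
    for op in ops:
        ins = [latest.get(x, x) for x in op.inputs]
        if op.inplace_on >= 0:
            result = next(fresh)
            out.append(Op(op.name.rstrip("_"), ins, [result]))  # e.g. add_ -> add
            written = next(fresh)  # plays the role of b(new) above
            out.append(Op("copy_", [ins[op.inplace_on], result], [written]))
            latest[op.inputs[op.inplace_on]] = written
        else:
            outs = [next(fresh) for _ in op.outputs]
            out.append(Op(op.name, ins, outs))
            latest.update(zip(op.outputs, outs))
    return out

trace = [
    Op("ones", [], ["b"]),
    Op("add_", ["a", "b"], ["a"], inplace_on=0),
]
for op in functionalize(trace):
    print(op)  # ones -> t0; add(a, t0) -> t1; copy_(a, t1) -> t2
```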

---

@IvanYashchuk yes, but the problem that I think @mruberry is referring to is the problem of views. In general, it's really difficult to know how many tensors are actually being modified by an inplace operation. Consider the following:

```python
a = torch.zeros(5)
b = a[0:1]
c = a[1:3]
b[0] = 1
c[1] = 1
print(a is b)  # False
print(a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr())  # True
print(a)  # [1, 0, 1, 0, 0]
```

How do we know, when we write to `b` or `c`, that we're also writing to the backing storage of `a`? I'm... not really saying that I have a general solution for the "how many tensors am I actually changing when I do an inplace operation on a view" problem. But I do think that we can warn users when they do inplace operations on tensors that we know are a view. And I also think that the other cases are pretty easy. If they aren't, let me know where I'm wrong.

I see that @t-vi just posted, and I think we independently came to largely the same conclusion. It should be easy enough to tag which tensors are views. So, we'd have to add two bits of info to meta functions. Or to tensors, or some combination. Either works.
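Eager PyTorch already exposes enough information for this kind of tagging: `Tensor._base` is non-None for tensors created as views. It is a semi-public attribute, so treat this as an illustration of the idea rather than a recommendation:

```python
import torch

def warn_if_inplace_on_view(t: torch.Tensor, op: str) -> None:
    # ._base is None for tensors that own their storage,
    # and points to the viewed tensor otherwise
    if t._base is not None:
        print(f"warning: {op} writes through a view; the base tensor changes too")

a = torch.zeros(5)
b = a[0:2]
warn_if_inplace_on_view(b, "add_")  # fires, since b views a
b.add_(1)                           # a[0:2] is updated as well
```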

---

The thing I don't like about this approach is that there's nothing in the program's dataflow that forces these operations to stay ordered. A dataflow-based approach to ensure the correctness of these operations might look something like

```python
a, a_storage0 = torch.randn(4)
b, _ = a.view(2, 2, storage=a_storage0)[1]  # maybe b, a_storage0 = would be clearer?
c, a_storage1 = torch.add_(b, 1, storage=a_storage0)
```

and then if someone performed a later operation on `a`, it would have to consume `a_storage1`, ordering it after the inplace add.

Now obviously that's pretty ugly, and I think there's room for refinement, but if we have some mechanism for accepting and updating storage generations, like in the above, then you can order the operations through dataflow alone. I think what's tricky about this approach is thinking about how best to associate views and storage, but the tuple (view, storage) seems like a really good perspective to have.

Edit: we could probably make every current TensorProxy a tuple of view and storage so that calls like
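To make the storage-generation idea concrete, here's a minimal sketch; `StorageProxy`, `TensorProxy`, and `inplace_add` are hypothetical names, not Thunder's actual classes:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class StorageProxy:
    name: str        # identity of the backing storage
    generation: int  # bumped by every inplace write

@dataclass(frozen=True)
class TensorProxy:
    name: str
    storage: StorageProxy  # the storage this view reads from and writes to

def inplace_add(t: TensorProxy):
    # Consumes the current storage generation and produces the next one,
    # so any later op that reads this storage depends on the write
    # through dataflow alone; no extra ordering mechanism is needed.
    new_storage = StorageProxy(t.storage.name, t.storage.generation + 1)
    return TensorProxy(t.name, new_storage), new_storage
```

Because each write consumes generation N and produces N+1, a reordering pass that respects dataflow can never move a read of generation N past the write that produced it.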

---

Seems like #264 would also benefit from an SSA/functionalization pass, as it also deals with implicit state (except that it seems simpler in that we don't need to worry about aliasing).

---

Want to log one thing @mruberry mentioned in an offline discussion: the scope of aliasing isn't limited to exact aliases, since tensors can also partially overlap. I think that makes SSA trickier to handle, since we might not be able to reason about how to replay an inplace update on aliased inputs.

Assuming a and b are exact aliases, SSA would need to replay the inplace update first, and then de-SSA to write the update back to a.buffer. But if a and b merely overlap, we wouldn't be able to replay the update this way.
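A quick eager-mode illustration of the alias-versus-overlap distinction:

```python
import torch

base = torch.zeros(6)
a = base[0:4]
b = base[2:6]  # a and b are views of base that partially overlap
a.add_(1)      # also changes b[0] and b[1], which alias a[2] and a[3]
print(b)       # tensor([1., 1., 0., 0.])
```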

---

I think @t-vi had an interesting idea to have an instruction which could provide information about the aliasing relationships in the trace, and that might help address this?

---

Yeah, that's what we want to do. The question is how exactly we update

---

Agreed! I don't think this problem is, in general, solvable in time polynomial in the shapes and strides of the tensors. If the inplace operations are explicitly represented, I don't think we have that problem.
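For intuition about why this is hard, consider two strided views whose overlap depends on index coincidences rather than on simple range checks (an illustrative example):

```python
import torch

base = torch.zeros(12)
# x touches storage elements 0, 4, 8; y touches 2, 5, 8, 11.
x = torch.as_strided(base, size=(3,), stride=(4,), storage_offset=0)
y = torch.as_strided(base, size=(4,), stride=(3,), storage_offset=2)
x.add_(1)
print(y)  # tensor([0., 0., 1., 0.]): the views collide only at element 8
```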

---

This issue is to facilitate discussion of inplace handling, namely the "big" solution of having a static single assignment (SSA) representation.

For any handling of inplace, we want to make certain that two things are achieved:

- correctness of the results (inplace updates must be visible everywhere they should be),
- dataflow still describes the order of operations (and admissible reorderings).

Some thoughts from video/chat discussions:

About the problem:

- inplace operations can modify more tensors than the one being written to (aliasing tensors also have memory that is to be updated), so we would need to know about views (the fancy term is alias analysis),
- some operations may or may not create views (e.g. `reshape`),

Solution considerations:

- a first version could use conservative alias analysis (tricky for `reshape`, easy for most others),

Later versions could refine the alias analysis as needed.

@tfogal @mruberry @IvanYashchuk