PR #1110 nearly doubles the compilation & execution time of a copy-heavy program (safely creating kernels with aliasing) #1173
Comments
Apparently:

```cpp
Val* set(Val* v) {
  Val* out = ops::newValLike(v, v->getDataType().value());
  IrBuilder::create<LoadStoreOp>(LoadStoreOpType::Set, out, v);
  return out;
}
```
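In other words, `set` produces a fresh tensor backed by a `LoadStoreOp` copy rather than an alias of its input. Below is a minimal sketch of how it can show up through the nvFuser Python frontend; the `define_tensor` arguments, dtype, and shapes are assumptions about the current frontend API, not taken from this issue.

```python
# Hedged sketch (assumed nvFuser Python-frontend API; illustrative only).
import torch
from nvfuser import FusionDefinition, DataType

with FusionDefinition() as fd:
    t0 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.Float)
    t1 = fd.ops.set(t0)   # explicit copy of t0 (a LoadStoreOp under the hood)
    fd.add_output(t1)

# The output is a new tensor holding a copy of the input, not a view of it.
(out,) = fd.execute([torch.ones(4, device="cuda")])
```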
@jjsjann123 could you take a look?
```python
fd.add_output(A, B)
fd.add_output(C, D)
```

Consider the above fusion definition. If I understand correctly, #1110 intends to avoid one tensor appearing twice, i.e. (A or B) and (C or D) aliasing each other. I suspect that, when there is no direct call of `copy_`, such aliasing cannot occur. For example:

```python
def f(a, b):
    a.add_(b)
```

Its trace is

```python
def computation(a, b):
  # a: "cuda:0 f32[]"
  # b: "cuda:0 f32[]"

  # /opt/pytorch/lightning-thunder/mshinokawa/sandbox/debug.py:78: a.add_(b)
  t1 = ltorch.add_(a, b, alpha=None)  # t1: "cuda:0 f32[]"
    # t0 = ltorch.add(a, b, alpha=None)  # t0: "cuda:0 f32[]"
      # t0 = prims.add(a, b)  # t0: "cuda:0 f32[]"
    # t1 = prims.copy_(t0, a)  # t1: "cuda:0 f32[]"
  return {'output': None, 'flat_args': [a, b]}
```

```python
# Constructed by Dead Code Elimination (took 0 milliseconds)
def computation(a, b):
  # a: "cuda:0 f32[]"
  # b: "cuda:0 f32[]"

  # Functionalized from `t1 = add_(a,b,None)`
  t0 = ltorch.add(a, b, alpha=None)  # t0: "cuda:0 f32[]"
    # t0 = prims.add(a, b)  # t0: "cuda:0 f32[]"

  # /opt/pytorch/lightning-thunder/mshinokawa/sandbox/debug.py:78: a.add_(b)
  t1 = prims.copy_(t0, a)  # t1: "cuda:0 f32[]"
  return {'output': None, 'flat_args': [t1, b]}
```

The LHS of `prims.copy_` here is a freshly created proxy. Nonetheless, we must pay attention to direct calls of `copy_`, e.g.

```python
def f(a, b, c):
    c.copy_(a)
    c.copy_(b)
```

So I suggest applying #1110's fix only to plain `copy_` calls. @crcrpar Would you check if this is correct?
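To make that hazard concrete, here is a minimal runnable sketch (device, shape, and values are illustrative, not the original repro): the same tensor is the destination of two copies, so any generated kernel has to preserve their order to match eager semantics.

```python
import torch

def f(a, b, c):
    c.copy_(a)
    c.copy_(b)  # the later copy must win

a = torch.randn(4, device="cuda")
b = torch.randn(4, device="cuda")
c = torch.empty(4, device="cuda")

f(a, b, c)
assert torch.equal(c, b)  # eager semantics: c ends up holding b's values
```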
This could be a good opportunity to prepare a better construct for

rel: #1084

Another solution is to make nvFuserExecutor track all the arguments to
That's a surprising regression. Wondering if you can give some repro scripts so I can investigate what's causing the regression. cc'ing @shino16

For the question re: `fd.ops.set`, the PR adding
This isn't a question of nvFuser increasing compilation time; it is a question of what is proper to do for aliasing. nvFuser would not be increasing compilation time due to segmentation if there were not a question of aliasing and of safely creating a functional set of kernels. Jie said this:
So the nvFuser regression somewhat makes sense. This is actually a case that shows the fix in PR #1110 is doing the right thing, i.e. reverting that PR changes the program semantics slightly. The pattern I'm seeing in the example is that:

And before the patch we have something like

The difference here is how many outputs are generated from the kernel.
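As a hedged illustration of that before/after difference (tensor names, shapes, and the exact frontend calls are assumptions, not the fusion from the actual repro): before #1110 the computed tensor is aliased straight back to the input, while after #1110 an explicit `fd.ops.set` copy sits between the computation and the aliased output, so the kernel does extra work per copy.

```python
# Hedged before/after sketch (assumed nvFuser Python-frontend API; illustrative only).
import torch
from nvfuser import FusionDefinition, DataType

def before(fd: FusionDefinition):
    t0 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.Float)
    t1 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.Float)
    t2 = fd.ops.add(t0, t1)
    fd.add_output(t2, t0)       # t2 is written directly back into t0's buffer

def after(fd: FusionDefinition):
    t0 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.Float)
    t1 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.Float)
    t2 = fd.ops.add(t0, t1)
    t3 = fd.ops.set(t2)         # explicit copy introduced for prims.copy_
    fd.add_output(t3, t0)       # the copy, not t2, is what gets aliased back to t0

inputs = [torch.randn(8, device="cuda"), torch.randn(8, device="cuda")]
for fusion_func in (before, after):
    with FusionDefinition() as fd:
        fusion_func(fd)
    fd.execute(inputs)          # both add t1 into t0 in place; the "after" form adds a copy
```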
^^^ That explains the execution time regression. The compilation time regression is a bit strange... I'm seeing 4 kernels instead of 2 kernels compiled after the change. Given that the original issue predates our fix for the IO race through segmentation, I'm not totally sure why. I'll follow up with a smaller repro in nvFuser to track that.
One last note: the repro here (#1173 (comment)) can be greatly simplified by replacing the real model with a simple linear layer.
Did a further dig on the nvFuser regression; the observed regression is coming from an issue in the mark-alias pass. Compilation time: 66s, execution time: 13.15ms on an nvFuser branch. Note this isn't patching the real issue here, which is that Thunder is not giving us the right fusion definition.
Unassigned myself since the root cause for this regression should be addressed by #1209.
@jjsjann123 - Is this done?
I'm marking this issue done, since the segmentation on the nvFuser side from alias analysis is patched and the regression is coming from the Thunder representation. On ToT right now, with nvFuser commit bb058595c49dc32416d563f5a4c1c5f22a01ca54, I'm seeing one nvFuser region as below:

You can see the representation here:

We are requesting t7 as the output from nvFusion; looking at the definition here, t7 shouldn't be aliased to argx. #1110 correctly fixed that. I'm closing this issue as not planned, since the optimizer is not of high priority at this moment.
I took a benchmark of `torch.compile`'d `Adam.step` with the Thunder backend. Surprisingly, the compiled `Adam.step` was even slower than eager mode. The configurations compared were:

- `torch.compile(adam.step)`
- `torch.compile(adam.step, backend=thunder)`
- `torch.compile(adam.step, backend=thunder)` with #1110 reverted

#1110 adds `fd.ops.set` on every `prims.copy_` (diff). What is `fd.ops.set`? Can we avoid using this op?

cc @tfogal
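For reference, a hedged sketch of this kind of benchmark (the model, sizes, timing loop, and the use of `thunder.dynamo.ThunderCompiler` as the `torch.compile` backend are assumptions, not the original script; a later comment notes a simple linear layer is enough to reproduce the behavior):

```python
import torch
from thunder.dynamo import ThunderCompiler  # assumed entry point for the thunder backend

# Toy model and optimizer standing in for the real workload.
model = torch.nn.Linear(1024, 1024, device="cuda")
opt = torch.optim.Adam(model.parameters())
model(torch.randn(64, 1024, device="cuda")).sum().backward()  # populate grads

variants = {
    "eager": opt.step,
    "torch.compile": torch.compile(opt.step),
    "torch.compile+thunder": torch.compile(opt.step, backend=ThunderCompiler()),
}

for name, step in variants.items():
    step()  # warm-up (includes compilation for the compiled variants)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    step()
    end.record()
    torch.cuda.synchronize()
    print(f"{name}: {start.elapsed_time(end):.3f} ms per step")
```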