
PR #1110 nearly doubles the compilation & execution time of a copy-heavy program (safely creating kernels with aliasing) #1173

Closed
shino16 opened this issue Sep 19, 2024 · 16 comments
Labels: enhancement, nvfuser, thunderfx

@shino16 (Contributor) commented Sep 19, 2024

I benchmarked torch.compile'd Adam.step with the Thunder backend. Surprisingly, the compiled Adam.step was even slower than eager mode.

|  | compilation (s) | execution (ms) |
| --- | --- | --- |
| eager | 0.0 | 10.87 |
| torch.compile(adam.step) | 21.0 | 5.83 |
| torch.compile(adam.step, backend=thunder) | 95.0 | 11.59 |
| torch.compile(adam.step, backend=thunder), #1110 reverted | 46.0 | 6.24 |

#1110 adds fd.ops.set on every prims.copy_ (diff). What is fd.ops.set? Can we avoid using this op?

cc @tfogal

@shino16 added the enhancement label Sep 19, 2024
@shino16 (Contributor, Author) commented Sep 20, 2024

Apparently, fd.ops.set creates an intermediate no-op node that binds what it receives to another TensorView.

https://github.com/NVIDIA/Fuser/blob/aad728668853e4302c00d8e98a50f13cc5db4184/csrc/ops/alias.cpp#L19-L23

```cpp
Val* set(Val* v) {
  Val* out = ops::newValLike(v, v->getDataType().value());
  IrBuilder::create<LoadStoreOp>(LoadStoreOpType::Set, out, v);
  return out;
}
```

IrBuilder::create<LoadStoreOp>(LoadStoreOpType::Set, out, v); is used in other places too, e.g. when fd.ops.expand does not have to expand anything (link). So fd.ops.set does not necessarily make a copy; it just binds a separate node to the output of another.

@crcrpar (Collaborator) commented Sep 20, 2024

@jjsjann123 could you take a look?

@shino16 (Contributor, Author) commented Sep 20, 2024

```python
fd.add_output(A, B)
fd.add_output(C, D)
```

Consider the above fusion definition. If I understand correctly, #1110 intends to avoid one tensor appearing twice, i.e. (A or B) and (C or D) aliasing each other.

I suspect that, when there is no direct call of Tensor.copy_ and all the other in-place ops are functionalized, this aliasing never happens even without #1110. For example, consider

```python
def f(a, b):
  a.add_(b)
```

Its trace is

```python
def computation(a, b):
  # a: "cuda:0 f32[]"
  # b: "cuda:0 f32[]"

  # /opt/pytorch/lightning-thunder/mshinokawa/sandbox/debug.py:78:          a.add_(b)
  t1 = ltorch.add_(a, b, alpha=None)  # t1: "cuda:0 f32[]"
    # t0 = ltorch.add(a, b, alpha=None)  # t0: "cuda:0 f32[]"
      # t0 = prims.add(a, b)  # t0: "cuda:0 f32[]"
    # t1 = prims.copy_(t0, a)  # t1: "cuda:0 f32[]"
  return {'output': None, 'flat_args': [a, b]}

# Constructed by Dead Code Elimination (took 0 milliseconds)
def computation(a, b):
  # a: "cuda:0 f32[]"
  # b: "cuda:0 f32[]"
  # Functionalized from `t1 = add_(a,b,None)`
  t0 = ltorch.add(a, b, alpha=None)  # t0: "cuda:0 f32[]"
    # t0 = prims.add(a, b)  # t0: "cuda:0 f32[]"

  # /opt/pytorch/lightning-thunder/mshinokawa/sandbox/debug.py:78:          a.add_(b)
  t1 = prims.copy_(t0, a)  # t1: "cuda:0 f32[]"
  return {'output': None, 'flat_args': [t1, b]}
```

The first argument of copy_, t0, is the output of the arithmetic op immediately preceding prims.copy_. As long as the preceding arithmetic op does not produce an alias, and the destinations of copies do not alias each other (which is guaranteed by #798), we never have unwanted aliases.
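A minimal NumPy sketch of this argument (hypothetical; the arrays stand in for the CUDA tensors, and the comments name the corresponding trace ops):

```python
import numpy as np

# Functionalized a.add_(b): the out-of-place add allocates a fresh t0,
# and the copy-back only writes t0's values into a's storage.
a = np.array([1.0, 2.0])
b = np.array([3.0, 4.0])

t0 = a + b                           # ltorch.add -> fresh allocation
assert not np.shares_memory(t0, a)   # the copy source never aliases an input
a[...] = t0                          # prims.copy_(t0, a)
assert (a == np.array([4.0, 6.0])).all()
```

Because t0 is freshly allocated by the arithmetic op, registering it as an output cannot alias any other output.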

Nonetheless, we must pay attention to direct calls of Tensor.copy_ that have no preceding arithmetic op. For example, we could write

```python
def f(a, b, c):
  c.copy_(a)
  c.copy_(b)
```
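A NumPy sketch of the hazard (hypothetical; arrays stand in for tensors): without a set-style fresh buffer, both copy results alias c, so the first result is silently clobbered:

```python
import numpy as np

a = np.array([1.0, 2.0])
b = np.array([3.0, 4.0])
c = np.zeros(2)

# Without fd.ops.set: each copy_'s "output" is just c itself.
c[...] = a
t1 = c                      # aliases c
c[...] = b
t2 = c                      # also aliases c
assert (t1 == b).all()      # t1 no longer holds a's values -- clobbered

# With a set-style fresh buffer, the first result survives.
c[...] = a
u1 = c.copy()               # bind the value to a new buffer
c[...] = b
assert (u1 == a).all() and (c == b).all()
```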

So I suggest applying #1110's fix only to plain Tensor.copy_. As Adam.step does not involve direct copy_ calls, this would restore the Thunder-jitted Adam.step to its original efficiency.

@crcrpar Would you check if this is correct?

@shino16 (Contributor, Author) commented Sep 20, 2024

This could be a good opportunity to prepare a better construct for Tensor.copy_. Currently it is translated directly into prims.copy_, but Tensor.copy_ allows broadcasting, dtype casting, and device transfer.

@crcrpar (Collaborator) commented Sep 20, 2024

> but Tensor.copy_ allows broadcast, dtype cast and device transfer.

rel: #1084

@shino16 (Contributor, Author) commented Sep 20, 2024

Another solution is to make nvFuserExecutor track all the arguments to add_output and insert fd.ops.set where needed. This is more work, but more general and robust.
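A hypothetical sketch of such tracking (plan_outputs and the id-based bookkeeping are illustrative, not thunder's actual API): the executor records every value passed to add_output and flags a value for fd.ops.set only when it would otherwise occupy more than one output slot:

```python
def plan_outputs(output_ids):
    """Return the ids that must be wrapped in a set-style copy because
    they would otherwise be registered as an output more than once."""
    seen, needs_set = set(), set()
    for vid in output_ids:
        if vid in seen:
            needs_set.add(vid)
        seen.add(vid)
    return needs_set

# Example: T7 is both written back into T0 (alias output) and returned
# as a regular output, so only T7 needs a fresh fd.ops.set buffer.
assert plan_outputs(["T7", "T7", "T15"]) == {"T7"}
assert plan_outputs(["T7", "T15"]) == set()
```

This would keep the common single-output case free of extra buffers while still guarding the duplicated-output case.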

@jjsjann123 (Collaborator) commented

> |  | compilation (s) | execution (ms) |
> | --- | --- | --- |
> | eager | 0.0 | 10.87 |
> | torch.compile(adam.step) | 21.0 | 5.83 |
> | torch.compile(adam.step, backend=thunder) | 95.0 | 11.59 |
> | torch.compile(adam.step, backend=thunder), #1110 reverted | 46.0 | 6.24 |
>
> #1110 adds fd.ops.set on every prims.copy_ (diff). What is fd.ops.set? Can we avoid using this op?

That's a surprising regression. Wondering if you can give a repro script so I can investigate what's causing the regression. cc'ing @shino16

For the question re: fd.ops.set: Out of curiosity, does #798 also detect cases when we are returning t0? vvv

```python
t0 = ltorch.add(a, b, alpha=None)  # t0: "cuda:0 f32[]"
t1 = prims.copy_(t0, a)  # t1: "cuda:0 f32[]"
```

The PR adding fd.ops.set is trying to avoid returning aliases. That seems to be the simplest way to avoid issues caused by side effects when multiple aliases share storage.
I tend to agree with @t-vi on his suggestion in #1177: we should have a precise way to specify in-place update vs. copying. That may make it easier for the nvfuser executor to construct a simpler fusion.

@tfogal added the thunderfx label Sep 20, 2024
@shino16 (Contributor, Author) commented Sep 21, 2024

> Wondering if you can give a repro script so I can investigate what's causing the regression. cc'ing @shino16

I should have done so. The benchmark script is on this gist. You can try reverting 7c9cd8c.

And thank you for your opinion about implementation!

@jjsjann123 self-assigned this Sep 26, 2024
@kevinstephano changed the title from "PR #1110 nearly doubles the compilation & execution time of a copy-heavy program" to "PR #1110 nearly doubles the compilation & execution time of a copy-heavy program (safely creating kernels with aliasing)" Oct 14, 2024
@kevinstephano (Collaborator) commented Oct 14, 2024

This isn't a question of nvFuser increasing compilation time; it is a question of what is proper to do for aliasing. nvFuser would not increase compilation time due to segmentation if there were no need to handle aliasing and safely create a functional set of kernels.

Jie said this:

> I tend to agree with @t-vi on his suggestion in #1177: we should have a precise way to specify in-place update vs. copying. That may make it easier for the nvfuser executor to construct a simpler fusion.

@crcrpar assigned crcrpar and unassigned shino16 Oct 15, 2024
@jjsjann123 (Collaborator) commented

So the nvfuser regression somewhat makes sense.

This is actually a case that shows the fix in PR #1110 is doing the right thing. i.e. reverting that PR changes the program semantics slightly.

The pattern I'm seeing in the example is that:

```python
T0 = fd.define_tensor(shape=[3, 2], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
T1 = fd.define_tensor(shape=[3, 2], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
# ...
T7 = fd.ops.lerp(T0, T1, S6)
T16 = fd.ops.set(T7)
fd.add_output(T16, T0)  # Create an alias output T16 to T0. This updates T0 in-place.
# ...
fd.add_output(T7)  # Create another output T7 and copy the value out.
```

And before the patch we have something like

```python
T0 = fd.define_tensor(shape=[3, 2], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
T1 = fd.define_tensor(shape=[3, 2], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
# ...
T7 = fd.ops.lerp(T0, T1, S6)
fd.add_output(T7, T0)  # Create an alias output T7 to T0. This updates T0 in-place.
# ...
fd.add_output(T7)  # Annotate that T7 is a semantic output, so we return the aliased tensor T0.
```

The difference here is how many outputs are generated from the kernel. The patch creates an extra output buffer that doesn't alias any input. If we look at the thunder trace here vvv

```python
[t2, l_self_state_list_l_self_state_keys_0_step_, t29, addcdiv_] = nvFusion0(L_self_state_list_L_self_state_keys_0_exp_avg_, L_self_param_groups_0_params_0_grad, L_self_state_list_L_self_state_keys_0_step_, L_self_state_list_L_self_state_keys_0_exp_avg_sq_, conj, L_self_param_groups_0_params_0_)
  # t2 = prims.lerp(L_self_state_list_L_self_state_keys_0_exp_avg_, L_self_param_groups_0_params_0_grad, 0.09999999999999998)  # t2: "cuda:0 f32[3, 2]"
  # prims.copy_(t2, L_self_state_list_L_self_state_keys_0_exp_avg_)
```

t2 is not an alias to anything and the compiled program shouldn't be returning t2 as an alias.

@jjsjann123 (Collaborator) commented

^^^ That explains the execution time regression.

The compilation time regression is a bit strange... I'm seeing 4 kernels instead of 2 compiled after the change. Given that the original issue predates our fix of the IO race through segmentation, I'm not totally sure why. I'll follow up with a smaller repro in nvfuser to track that.

@jjsjann123 (Collaborator) commented

One last note: the repro here (#1173 (comment)) can be greatly simplified by replacing the real model with a simple linear layer.

@jjsjann123 (Collaborator) commented

Dug further into the nvfuser regression; the observed regression comes from an issue in the mark-alias pass.

With #1110:
compilation time: 66s, execution time: 13.15ms

With #1110 reverted:
compilation time: 44s, execution time: 8.59ms

On the nvfuser fix branch:
compilation time: 47s, execution time: 11.03ms

Note this isn't patching the real issue here, which is that thunder is not giving us the right fusion definition.

@jjsjann123 (Collaborator) commented

Unassigned myself, since the root cause of this regression should be addressed by #1209.

@nvMelissa (Collaborator) commented

@jjsjann123 - Is this done?

@jjsjann123 (Collaborator) commented Nov 25, 2024

I'm marking this issue done, since the segmentation on the nvfuser side from alias analysis is patched and the remaining regression comes from the thunder representation.

On ToT right now, with nvfuser commit bb058595c49dc32416d563f5a4c1c5f22a01ca54 and thunder commit 81f83f3, I'm seeing one nvfuser region as below:

```python
[l_self_state_list_l_self_state_keys_0_list_l_self_state_list_l_self_state_keys_0_keys_0_, t7, t35, addcdiv_] = nvFusion0(L_self_state_list_L_self_state_keys_0_list_L_self_state_list_L_self_state_keys_0_keys_0_, item_1, item_2, L_self_state_list_L_self_state_keys_0_list_L_self_state_list_L_self_state_keys_0_keys_1_, L_self_param_groups_0_params_0_grad, sub, L_self_state_list_L_self_state_keys_0_list_L_self_state_list_L_self_state_keys_0_keys_2_, conj, sub_1, item, item_3, L_self_param_groups_0_params_0_)
  # t4 = prims.add(L_self_state_list_L_self_state_keys_0_list_L_self_state_list_L_self_state_keys_0_keys_0_, 1.0)  # t4: "cuda:0 f32[]"
  # pow_1 = prims.pow(item_1, t4)  # pow_1: "cuda:0 f32[]"
  # pow_2 = prims.pow(item_2, t4)  # pow_2: "cuda:0 f32[]"
  # l_self_state_list_l_self_state_keys_0_list_l_self_state_list_l_self_state_keys_0_keys_0_ = prims.copy_(t4, L_self_state_list_L_self_state_keys_0_list_L_self_state_list_L_self_state_keys_0_keys_0_)  # l_self_state_list_l_self_state_keys_0_list_l_self_state_list_l_self_state_keys_0_keys_0_: "cuda:0 f32[]"
  # t7 = prims.lerp(L_self_state_list_L_self_state_keys_0_list_L_self_state_list_L_self_state_keys_0_keys_1_, L_self_param_groups_0_params_0_grad, sub)  # t7: "cuda:0 f32[3, 2]"
  # t9 = prims.mul(L_self_state_list_L_self_state_keys_0_list_L_self_state_list_L_self_state_keys_0_keys_2_, item_2)  # t9: "cuda:0 f32[3, 2]"
  # mul = prims.mul(L_self_param_groups_0_params_0_grad, conj)  # mul: "cuda:0 f32[3, 2]"
  # mul_1 = prims.mul(mul, sub_1)  # mul_1: "cuda:0 f32[3, 2]"
  # t15 = prims.add(t9, mul_1)  # t15: "cuda:0 f32[3, 2]"
  # bias_correction1 = prims.sub(1.0, pow_1)  # bias_correction1: "cuda:0 f32[]"
  # bias_correction2 = prims.sub(1.0, pow_2)  # bias_correction2: "cuda:0 f32[]"
  # prims.copy_(t7, L_self_state_list_L_self_state_keys_0_list_L_self_state_list_L_self_state_keys_0_keys_1_)
  # sqrt_1 = prims.sqrt(t15)  # sqrt_1: "cuda:0 f32[3, 2]"
  # t35 = prims.copy_(t15, L_self_state_list_L_self_state_keys_0_list_L_self_state_list_L_self_state_keys_0_keys_2_)  # t35: "cuda:0 f32[3, 2]"
  # bias_correction2_sqrt = prims.sqrt(bias_correction2)  # bias_correction2_sqrt: "cuda:0 f32[]"
  # step_size = prims.div(item, bias_correction1)  # step_size: "cuda:0 f32[]"
  # step_size_neg = prims.neg(step_size)  # step_size_neg: "cuda:0 f32[]"
  # mul_2 = prims.mul(bias_correction2_sqrt, step_size_neg)  # mul_2: "cuda:0 f32[]"
  # t26 = prims.broadcast_in_dim(mul_2, (3, 2), ())  # t26: "cuda:0 f32[3, 2]"
  # truediv_1 = prims.div(sqrt_1, t26)  # truediv_1: "cuda:0 f32[3, 2]"
  # truediv_2 = prims.div(item_3, step_size_neg)  # truediv_2: "cuda:0 f32[]"
  # t29 = prims.broadcast_in_dim(truediv_2, (3, 2), ())  # t29: "cuda:0 f32[3, 2]"
  # t30 = prims.add(truediv_1, t29)  # t30: "cuda:0 f32[3, 2]"
  # t32 = prims.div(t7, t30)  # t32: "cuda:0 f32[3, 2]"
  # t33 = prims.add(L_self_param_groups_0_params_0_, t32)  # t33: "cuda:0 f32[3, 2]"
  # addcdiv_ = prims.copy_(t33, L_self_param_groups_0_params_0_)  # addcdiv_: "cuda:0 f32[3, 2]"
```

You can see the representation here:

```python
[..., t7, ...] = nvFusion0(..., argx, ...):
  # t7 = lerp(argx, ...)
  # prims.copy_(t7, argx)
```

We are requesting t7 as an output from nvFusion; looking at the definition, t7 shouldn't be aliased to argx. #1110 correctly fixed that.
Prior to the patch, we'd end up returning t7 as an alias to argx, which is the wrong behavior, but it does save bandwidth, i.e. kernel time here.
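The wrong pre-#1110 behavior can be sketched with NumPy (hypothetical; arrays stand in for the CUDA tensors, and lerp is written out as its formula):

```python
import numpy as np

argx = np.array([0.0, 1.0])
other = np.array([2.0, 3.0])

t7 = argx + 0.5 * (other - argx)   # stand-in for lerp(argx, other, 0.5)
argx[...] = t7                      # prims.copy_(t7, argx)

aliased = argx                      # pre-#1110: returned "t7" aliases argx
fresh = t7                          # post-#1110: a distinct buffer

argx[...] = 0.0                     # any later in-place update of the param
assert (aliased == 0.0).all()       # caller's "t7" silently changed
assert (fresh == np.array([1.0, 2.0])).all()  # correct t7 value preserved
```

Returning the fresh buffer costs an extra write (the bandwidth mentioned above), but it is the only semantically correct output.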

I'm closing this issue as not planned, since the optimizer path is not of high priority at the moment.

@jjsjann123 closed this as not planned Nov 25, 2024