conv backward in thunder #799

Open
t-vi opened this issue Jul 18, 2024 · 7 comments
@t-vi
Collaborator

t-vi commented Jul 18, 2024

While testing #797, it seems that Thunder's backward might not be optimal (accuracy? speed?):

import torch
import thunder

def foo(x, w, b=None):
    return torch.nn.functional.conv2d(x, w, b)

x = torch.randn(1, 2, 8, 8, requires_grad=True)
w = torch.randn(3, 2, 4, 4, requires_grad=True)
b = torch.randn(3, requires_grad=True)
go = torch.randn(1, 3, 5, 5)

jfoo = thunder.jit(foo)

x64 = x.to(torch.float64)
w64 = w.to(torch.float64)
b64 = b.to(torch.float64)
ref_eager_out = foo(x64, w64, b64)
ref_eager_grads = torch.autograd.grad(ref_eager_out, [x64, w64, b64], go.to(torch.float64))

with torch.autocast("cpu", torch.float16):
    print("eager")
    with torch.profiler.profile() as prof:
        eager_out = foo(x, w, b)
        eager_grads = torch.autograd.grad(eager_out, [x, w, b], go)
    print(prof.key_averages().table())
    print("thunder")
    with torch.profiler.profile() as prof:
        jit_out = jfoo(x, w, b)
        jit_grads = torch.autograd.grad(jit_out, [x, w, b], go)
    print(prof.key_averages().table())

torch.testing.assert_close(eager_out, jit_out)


for eg, jg, rg in zip(eager_grads, jit_grads, ref_eager_grads):
    # TODO: tighten check?
    print(f"ref - eager {(eg - rg).abs().max().item():.4f} ref - thunder {(jg - rg).abs().max().item():.4f}")
    torch.testing.assert_close(eg, jg, atol=1e-2, rtol=1e-2)

gives the following (note the backward running the conv2d forward kernel twice more(?)):

eager
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::conv2d         4.64%      16.030us        99.37%     343.382us     171.691us             2  
                                               aten::to         2.59%       8.950us        20.87%      72.130us      10.304us             7  
                                         aten::_to_copy        10.49%      36.240us        18.28%      63.180us       9.026us             7  
                                    aten::empty_strided         4.68%      16.190us         4.68%      16.190us       2.313us             7  
                                            aten::copy_         6.26%      21.620us         6.26%      21.620us       2.702us             8  
                                      aten::convolution         4.14%      14.290us        40.00%     138.241us     138.241us             1  
                                     aten::_convolution         3.99%      13.790us        35.87%     123.951us     123.951us             1  
                                aten::_nnpack_available         0.25%       0.850us         0.25%       0.850us       0.425us             2  
                                      aten::thnn_conv2d         0.74%       2.570us        31.70%     109.561us     109.561us             1  
                             aten::_slow_conv2d_forward        24.09%      83.260us        30.96%     106.991us     106.991us             1  
                                            aten::empty         2.14%       7.390us         2.14%       7.390us       1.232us             6  
                                             aten::view         2.17%       7.491us         2.17%       7.491us       1.873us             4  
                                          aten::resize_         1.17%       4.060us         1.17%       4.060us       1.353us             3  
                                          aten::reshape         0.41%       1.400us         0.62%       2.150us       2.150us             1  
autograd::engine::evaluate_function: ConvolutionBack...         1.84%       6.350us        26.54%      91.701us      91.701us             1  
                                   ConvolutionBackward0         1.86%       6.430us        24.70%      85.351us      85.351us             1  
                             aten::convolution_backward         3.16%      10.911us        22.84%      78.921us      78.921us             1  
                            aten::_slow_conv2d_backward        10.20%      35.240us        19.61%      67.760us      67.760us             1  
                                       aten::resize_as_         0.46%       1.600us         0.86%       2.980us       2.980us             1  
                                            aten::zero_         0.40%       1.390us         0.40%       1.390us       0.695us             2  
                                              aten::sum         5.02%      17.360us         6.38%      22.050us      22.050us             1  
                                       aten::as_strided         0.52%       1.790us         0.52%       1.790us       1.790us             1  
                                            aten::fill_         0.84%       2.900us         0.84%       2.900us       2.900us             1  
autograd::engine::evaluate_function: ToCopyBackward0...         1.85%       6.410us         7.86%      27.150us       9.050us             3  
                                        ToCopyBackward0         1.22%       4.230us         6.00%      20.740us       6.913us             3  
autograd::engine::evaluate_function: torch::autograd...         0.54%       1.870us         0.54%       1.870us       0.623us             3  
                                  cudaDeviceSynchronize         4.33%      14.960us         4.33%      14.960us      14.960us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 345.572us

thunder
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                               aten::to         0.67%       4.030us         4.20%      25.310us       3.616us             7  
                                         aten::_to_copy         1.67%      10.060us         3.53%      21.280us       3.040us             7  
                                    aten::empty_strided         1.63%       9.790us         1.63%       9.790us       0.890us            11  
                                            aten::copy_         2.20%      13.260us         2.20%      13.260us       1.105us            12  
                                      aten::convolution         0.73%       4.390us        65.19%     392.634us     130.878us             3  
                                     aten::_convolution         1.44%       8.660us        64.46%     388.244us     129.415us             3  
                                aten::_nnpack_available         0.04%       0.260us         0.04%       0.260us       0.087us             3  
                                      aten::thnn_conv2d         0.29%       1.730us        62.19%     374.564us     124.855us             3  
                             aten::_slow_conv2d_forward        59.96%     361.134us        61.90%     372.834us     124.278us             3  
                                            aten::empty         0.69%       4.150us         0.69%       4.150us       0.593us             7  
                                             aten::view         0.84%       5.050us         0.84%       5.050us       0.842us             6  
                                          aten::resize_         0.34%       2.020us         0.34%       2.020us       0.673us             3  
                                          aten::reshape         1.22%       7.320us         2.41%      14.530us       2.906us             5  
                                        ThunderFunction         2.41%      14.510us         2.41%      14.510us      14.510us             1  
autograd::engine::evaluate_function: ThunderFunction...         0.64%       3.870us        34.79%     209.572us     209.572us             1  
                                ThunderFunctionBackward        17.19%     103.561us        34.15%     205.702us     205.702us             1  
                                          aten::permute         1.46%       8.770us         1.99%      11.980us       1.997us             6  
                                       aten::as_strided         0.64%       3.880us         0.64%       3.880us       0.485us             8  
                                              aten::sum         1.31%       7.920us         1.57%       9.480us       9.480us             1  
                                            aten::fill_         0.21%       1.240us         0.21%       1.240us       1.240us             1  
                                              aten::pad         0.59%       3.570us         2.54%      15.310us       5.103us             3  
                                  aten::constant_pad_nd         0.54%       3.230us         1.95%      11.740us       3.913us             3  
                                            aten::clone         0.62%       3.740us         2.08%      12.540us       3.135us             4  
                                   aten::_reshape_alias         0.63%       3.790us         0.63%       3.790us       1.895us             2  
                                             aten::flip         0.97%       5.830us         1.31%       7.890us       7.890us             1  
                                       aten::empty_like         0.26%       1.550us         0.50%       3.020us       1.510us             2  
                                       aten::contiguous         0.12%       0.730us         0.79%       4.760us       4.760us             1  
autograd::engine::evaluate_function: torch::autograd...         0.16%       0.960us         0.16%       0.960us       0.320us             3  
                                  cudaDeviceSynchronize         0.55%       3.330us         0.55%       3.330us       3.330us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 602.335us

Edit: I don't have the output of the accuracy check that I edited into the script here, but my impression is that the accuracy of the Thunder backward is not worse than the eager one in this example, and this example is not terribly relevant for perf. It's just that we should develop insight into what's going on, because we will bump into the question.

I wonder if #655 is related; maybe the same method could provide a 0th-order analysis of what's going on.

cc @tfogal

@nikitaved
Contributor

nikitaved commented Jul 18, 2024

The backward for conv is conv again, both for the input and for the weight. But let's indeed investigate the decomp's performance. Maybe, again, we can replace it with the PyTorch kernel... Maybe some things we'd better upcast in the decomp (for example, the reduction for the bias grad), unless such things are taken care of by NVFuser...
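For reference, a minimal sketch of that decomposition (assuming stride=1, padding=0, dilation=1, groups=1, i.e. the defaults used in the repro above; an illustration, not necessarily Thunder's exact decomp), checked against autograd:

import torch
import torch.nn.functional as F

x = torch.randn(1, 2, 8, 8, requires_grad=True)
w = torch.randn(3, 2, 4, 4, requires_grad=True)
b = torch.randn(3, requires_grad=True)
go = torch.randn(1, 3, 5, 5)

gx_ref, gw_ref, gb_ref = torch.autograd.grad(F.conv2d(x, w, b), [x, w, b], go)

# grad wrt input: conv of grad_output with spatially flipped,
# channel-transposed weights and "full" (k-1) padding; one extra conv forward
gx = F.conv2d(go, w.flip(-1, -2).transpose(0, 1), padding=(3, 3))

# grad wrt weight: conv of the input with grad_output as the kernel,
# batch and channel dims swapped; the second extra conv forward
gw = F.conv2d(x.transpose(0, 1), go.transpose(0, 1)).transpose(0, 1)

# grad wrt bias: a plain reduction over batch and spatial dims
gb = go.sum(dim=(0, 2, 3))

torch.testing.assert_close(gx, gx_ref)
torch.testing.assert_close(gw, gw_ref)
torch.testing.assert_close(gb, gb_ref)

Note that the aten::flip, aten::pad, and the two extra aten::convolution calls in the Thunder profile above line up with this shape of decomposition.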

@t-vi
Collaborator Author

t-vi commented Jul 18, 2024

The decomp at work is what we see with #655 in fp6, too. But I think we would want to look into whether the decomp has perf or numerical-accuracy impacts.
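One way to get a first read on the perf side, a rough torch.utils.benchmark sketch (shapes taken from the repro; decomp_grad_input is a made-up stand-in for just the input-grad conv of the decomposition sketched above):

import torch
import torch.nn.functional as F
from torch.utils import benchmark

x = torch.randn(1, 2, 8, 8)
w = torch.randn(3, 2, 4, 4)
go = torch.randn(1, 3, 5, 5)

def decomp_grad_input(go, w):
    # flip+pad conv, as in the decomposition sketched above
    return F.conv2d(go, w.flip(-1, -2).transpose(0, 1), padding=(3, 3))

# fused aten backward, asking only for grad_input via the output mask
fused = benchmark.Timer(
    stmt="torch.ops.aten.convolution_backward(go, x, w, [3], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, False, False])",
    globals={"torch": torch, "go": go, "x": x, "w": w},
)
decomp = benchmark.Timer(
    stmt="decomp_grad_input(go, w)",
    globals={"decomp_grad_input": decomp_grad_input, "go": go, "w": w},
)
print(fused.timeit(1000))
print(decomp.timeit(1000))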

@nikitaved
Contributor

nikitaved commented Jul 18, 2024

The grad tests are solid and super comprehensive, but they are done in high-precision modes. We might lose things between the calls to convolutions, though, when running in lower-precision modes (I expect PyTorch to behave well there in forward, but I am not sure), unless NVFuser handles these things with grace. And I am not 100% sure that is the case...
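A generic illustration of the concern (just the pattern, not Thunder's actual decomp): when an intermediate is materialized in fp16 between two unfused kernels, its rounding error feeds the next kernel, whereas a fused region can keep the intermediate wide and round only once at the end:

import torch

torch.manual_seed(0)
a = torch.randn(10_000, dtype=torch.float16)
b = torch.randn(10_000, dtype=torch.float16)

# two separate fp16 kernels: the product a*b is rounded to fp16
# before the reduction ever sees it
unfused = (a * b).sum()

# keep the intermediate in fp32 (as a fused kernel could), round once at the end
fused = (a.float() * b.float()).sum().to(torch.float16)

ref = (a.double() * b.double()).sum()
print((unfused.double() - ref).abs().item())
print((fused.double() - ref).abs().item())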

@nikitaved
Contributor

nikitaved commented Jul 18, 2024

I wonder, maybe conv.backward in PyTorch updates the grads in a single kernel? The best perf improvement I see would be to give NVFuser its own native convolution support; best in the sense of not introducing changes to Thunder...
Currently NVFuser will not claim conv, so the convs in the backward decomposition will end up being 2 separate kernel runs.
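On the single-kernel question: PyTorch exposes the whole conv backward as one op, aten::convolution_backward, which returns all three grads from a single call (whether that is literally one kernel underneath depends on the backend). A quick check with the shapes from the repro:

import torch
import torch.nn.functional as F

x = torch.randn(1, 2, 8, 8, requires_grad=True)
w = torch.randn(3, 2, 4, 4, requires_grad=True)
b = torch.randn(3, requires_grad=True)
go = torch.randn(1, 3, 5, 5)

gx_ref, gw_ref, gb_ref = torch.autograd.grad(F.conv2d(x, w, b), [x, w, b], go)

# one call produces all three grads; output_mask selects which ones to compute
gx, gw, gb = torch.ops.aten.convolution_backward(
    go, x, w,
    [3],                 # bias_sizes
    [1, 1], [0, 0],      # stride, padding
    [1, 1],              # dilation
    False, [0, 0], 1,    # transposed, output_padding, groups
    [True, True, True],  # output_mask: (grad_input, grad_weight, grad_bias)
)

torch.testing.assert_close(gx, gx_ref)
torch.testing.assert_close(gw, gw_ref)
torch.testing.assert_close(gb, gb_ref)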

cc @tfogal

@tfogal
Collaborator

tfogal commented Jul 18, 2024

Thanks for tagging me. I am sure we will want convs in nvFuser eventually, but it'll be quite some time; it's not even on the roadmap as of today. Plus, convs are hard :-). I generally trust nvFuser to get the best perf for memory-bound workloads today, but the situation is spotty for compute-bound workloads, and convs are generally compute-bound.

I think we should pursue other approaches, e.g., as you mention, seeing whether PyTorch has a single kernel here, or maybe we can just call cuDNN directly. cc @vedaanta

@t-vi t-vi removed the nvfuser label Jul 18, 2024
@t-vi
Collaborator Author

t-vi commented Jul 18, 2024

Note that this is more about investigation and knowing what's going on than any set conclusion. Comparing to fp64 in this trivial example, it seems that we're not worse than eager autocast in accuracy, so it might just be different.

@tfogal
Collaborator

tfogal commented Jul 22, 2024

triage review:

  • we are doing strange things in backward, e.g. using forward kernels
  • we'll probably hit this in the proxy model once we get things working
  • eager actually isn't doing a great job here; this is an opportunity, not a fault we have today
  • unknown how impactful this could be at present: somebody needs to dig in
