
conv backward in thunder

Open · t-vi opened this issue · 7 comments

While testing #797, it seems that Thunder's backward for conv2d might not be optimal (accuracy? speed?):

import torch
import thunder

def foo(x, w, b=None):
    return torch.nn.functional.conv2d(x, w, b)

# inputs; go matches the conv2d output shape (1, 3, 5, 5)
x = torch.randn(1, 2, 8, 8, requires_grad=True)
w = torch.randn(3, 2, 4, 4, requires_grad=True)
b = torch.randn(3, requires_grad=True)
go = torch.randn(1, 3, 5, 5)

jfoo = thunder.jit(foo)

# fp64 reference for the accuracy comparison below
x64 = x.to(torch.float64)
w64 = w.to(torch.float64)
b64 = b.to(torch.float64)
ref_eager_out = foo(x64, w64, b64)
ref_eager_grads = torch.autograd.grad(ref_eager_out, [x64, w64, b64], go.to(torch.float64))

with torch.autocast("cpu", torch.float16):
    print("eager")
    with torch.profiler.profile() as prof:
        eager_out = foo(x, w, b)
        eager_grads = torch.autograd.grad(eager_out, [x, w, b], go)
    print(prof.key_averages().table())
    print("thunder")
    with torch.profiler.profile() as prof:
        jit_out = jfoo(x, w, b)
        jit_grads = torch.autograd.grad(jit_out, [x, w, b], go)
    print(prof.key_averages().table())

torch.testing.assert_close(eager_out, jit_out)


for eg, jg, rg in zip(eager_grads, jit_grads, ref_eager_grads):
    # TODO: tighten check?
    print(f"ref - eager {(eg - rg).abs().max().item():.4f} ref - thunder {(jg - rg).abs().max().item():.4f}")
    torch.testing.assert_close(eg, jg, atol=1e-2, rtol=1e-2)

gives the following (note the backward running the conv2d forward kernel twice again(?)):

eager
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::conv2d         4.64%      16.030us        99.37%     343.382us     171.691us             2  
                                               aten::to         2.59%       8.950us        20.87%      72.130us      10.304us             7  
                                         aten::_to_copy        10.49%      36.240us        18.28%      63.180us       9.026us             7  
                                    aten::empty_strided         4.68%      16.190us         4.68%      16.190us       2.313us             7  
                                            aten::copy_         6.26%      21.620us         6.26%      21.620us       2.702us             8  
                                      aten::convolution         4.14%      14.290us        40.00%     138.241us     138.241us             1  
                                     aten::_convolution         3.99%      13.790us        35.87%     123.951us     123.951us             1  
                                aten::_nnpack_available         0.25%       0.850us         0.25%       0.850us       0.425us             2  
                                      aten::thnn_conv2d         0.74%       2.570us        31.70%     109.561us     109.561us             1  
                             aten::_slow_conv2d_forward        24.09%      83.260us        30.96%     106.991us     106.991us             1  
                                            aten::empty         2.14%       7.390us         2.14%       7.390us       1.232us             6  
                                             aten::view         2.17%       7.491us         2.17%       7.491us       1.873us             4  
                                          aten::resize_         1.17%       4.060us         1.17%       4.060us       1.353us             3  
                                          aten::reshape         0.41%       1.400us         0.62%       2.150us       2.150us             1  
autograd::engine::evaluate_function: ConvolutionBack...         1.84%       6.350us        26.54%      91.701us      91.701us             1  
                                   ConvolutionBackward0         1.86%       6.430us        24.70%      85.351us      85.351us             1  
                             aten::convolution_backward         3.16%      10.911us        22.84%      78.921us      78.921us             1  
                            aten::_slow_conv2d_backward        10.20%      35.240us        19.61%      67.760us      67.760us             1  
                                       aten::resize_as_         0.46%       1.600us         0.86%       2.980us       2.980us             1  
                                            aten::zero_         0.40%       1.390us         0.40%       1.390us       0.695us             2  
                                              aten::sum         5.02%      17.360us         6.38%      22.050us      22.050us             1  
                                       aten::as_strided         0.52%       1.790us         0.52%       1.790us       1.790us             1  
                                            aten::fill_         0.84%       2.900us         0.84%       2.900us       2.900us             1  
autograd::engine::evaluate_function: ToCopyBackward0...         1.85%       6.410us         7.86%      27.150us       9.050us             3  
                                        ToCopyBackward0         1.22%       4.230us         6.00%      20.740us       6.913us             3  
autograd::engine::evaluate_function: torch::autograd...         0.54%       1.870us         0.54%       1.870us       0.623us             3  
                                  cudaDeviceSynchronize         4.33%      14.960us         4.33%      14.960us      14.960us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 345.572us

thunder
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                               aten::to         0.67%       4.030us         4.20%      25.310us       3.616us             7  
                                         aten::_to_copy         1.67%      10.060us         3.53%      21.280us       3.040us             7  
                                    aten::empty_strided         1.63%       9.790us         1.63%       9.790us       0.890us            11  
                                            aten::copy_         2.20%      13.260us         2.20%      13.260us       1.105us            12  
                                      aten::convolution         0.73%       4.390us        65.19%     392.634us     130.878us             3  
                                     aten::_convolution         1.44%       8.660us        64.46%     388.244us     129.415us             3  
                                aten::_nnpack_available         0.04%       0.260us         0.04%       0.260us       0.087us             3  
                                      aten::thnn_conv2d         0.29%       1.730us        62.19%     374.564us     124.855us             3  
                             aten::_slow_conv2d_forward        59.96%     361.134us        61.90%     372.834us     124.278us             3  
                                            aten::empty         0.69%       4.150us         0.69%       4.150us       0.593us             7  
                                             aten::view         0.84%       5.050us         0.84%       5.050us       0.842us             6  
                                          aten::resize_         0.34%       2.020us         0.34%       2.020us       0.673us             3  
                                          aten::reshape         1.22%       7.320us         2.41%      14.530us       2.906us             5  
                                        ThunderFunction         2.41%      14.510us         2.41%      14.510us      14.510us             1  
autograd::engine::evaluate_function: ThunderFunction...         0.64%       3.870us        34.79%     209.572us     209.572us             1  
                                ThunderFunctionBackward        17.19%     103.561us        34.15%     205.702us     205.702us             1  
                                          aten::permute         1.46%       8.770us         1.99%      11.980us       1.997us             6  
                                       aten::as_strided         0.64%       3.880us         0.64%       3.880us       0.485us             8  
                                              aten::sum         1.31%       7.920us         1.57%       9.480us       9.480us             1  
                                            aten::fill_         0.21%       1.240us         0.21%       1.240us       1.240us             1  
                                              aten::pad         0.59%       3.570us         2.54%      15.310us       5.103us             3  
                                  aten::constant_pad_nd         0.54%       3.230us         1.95%      11.740us       3.913us             3  
                                            aten::clone         0.62%       3.740us         2.08%      12.540us       3.135us             4  
                                   aten::_reshape_alias         0.63%       3.790us         0.63%       3.790us       1.895us             2  
                                             aten::flip         0.97%       5.830us         1.31%       7.890us       7.890us             1  
                                       aten::empty_like         0.26%       1.550us         0.50%       3.020us       1.510us             2  
                                       aten::contiguous         0.12%       0.730us         0.79%       4.760us       4.760us             1  
autograd::engine::evaluate_function: torch::autograd...         0.16%       0.960us         0.16%       0.960us       0.320us             3  
                                  cudaDeviceSynchronize         0.55%       3.330us         0.55%       3.330us       3.330us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 602.335us
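
A quick way to confirm the extra forward-conv calls from the two profiles above (a hedged sketch; the FunctionEventAvg entries from key_averages() expose key and count):

# count how often each conv-related op ran in the last profile; in the
# thunder run, aten::convolution shows 3 calls: 1 forward + 2 in backward
conv_calls = {e.key: e.count for e in prof.key_averages() if "conv" in e.key}
print(conv_calls)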

Edit: I don't have the output of the accuracy check that I edited into the script here, but my impression is that the accuracy of the Thunder backward is not worse than eager's in this example, and this example is not terribly relevant for perf. It's just that we should develop insight into what's going on, because we will bump into this question.

I wonder if #655 is related; maybe the same method could provide a 0th-order analysis of what's going on.

cc @tfogal

t-vi · Jul 18 '24 10:07

The backward of conv is conv again, both for the input and for the weight. But let's indeed investigate the decomposition's performance. Maybe, again, we can replace it with the PyTorch kernel... Maybe some things would be better upcast in the decomp (for example, the reduction for the bias gradient), unless such things are taken care of by nvFuser...
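
(A minimal sketch of that point, for the stride-1, no-padding case used in this issue: both data gradients of conv2d are themselves convolutions, and torch.nn.grad exposes the reference decompositions directly.)

import torch
import torch.nn.functional as F

x = torch.randn(1, 2, 8, 8, requires_grad=True)
w = torch.randn(3, 2, 4, 4, requires_grad=True)
go = torch.randn(1, 3, 5, 5)

out = F.conv2d(x, w)
gx, gw = torch.autograd.grad(out, [x, w], go)

# grad wrt input: a (transposed) convolution of grad_output with the weight
gx_ref = torch.nn.grad.conv2d_input(x.shape, w, go)
# grad wrt weight: a convolution of the input with grad_output
gw_ref = torch.nn.grad.conv2d_weight(x, w.shape, go)

torch.testing.assert_close(gx, gx_ref)
torch.testing.assert_close(gw, gw_ref)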

nikitaved · Jul 18 '24 10:07

That the decomp works is what we see with #655 in fp6, too. But I think we would want to look into whether the decomp has perf or numerical-accuracy impacts.

t-vi · Jul 18 '24 11:07

The grad tests are solid and super comprehensive, but they are done in high-precision modes. We might lose precision between the calls to the convolutions, though (I expect PyTorch to behave well there in forward, but I am not sure), when run in lower-precision modes, unless nvFuser handles these things with grace. And I am not 100% sure that is the case...
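
(As an illustration of the kind of upcast meant here, a hedged sketch; the helper name is hypothetical and not Thunder's actual decomposition:)

import torch

# hypothetical bias-grad piece of a conv2d backward decomposition: when the
# surrounding computation runs in fp16, accumulating the reduction in fp32
# and casting back avoids losing mass in the sum over batch and spatial dims
def conv2d_bias_grad(grad_output: torch.Tensor) -> torch.Tensor:
    return grad_output.to(torch.float32).sum(dim=(0, 2, 3)).to(grad_output.dtype)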

nikitaved · Jul 18 '24 11:07

I wonder whether conv backward in PyTorch updates the grads in a single kernel? The best perf improvement I see is to give nvFuser its own native convolution support. I mean best in the sense of not introducing changes to Thunder... Currently nvFuser will not claim conv, and the convs in the backward decomposition will end up being two kernel runs.
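
(For reference, eager PyTorch routes the whole backward through the single op aten::convolution_backward, whose output_mask selects which of grad_input, grad_weight, and grad_bias to compute; whether that is one kernel underneath depends on the backend (the CPU profile above lowers it to aten::_slow_conv2d_backward). A hedged sketch of calling it directly, with the shapes from the script:)

# one call produces all requested gradients; output_mask selects them
gi, gw, gb = torch.ops.aten.convolution_backward(
    go, x.detach(), w.detach(),
    [b.shape[0]],            # bias_sizes
    [1, 1], [0, 0], [1, 1],  # stride, padding, dilation
    False, [0, 0], 1,        # transposed, output_padding, groups
    [True, True, True],      # output_mask
)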

cc @tfogal

nikitaved · Jul 18 '24 11:07

Thanks for tagging me. I am sure we will want convs in nvFuser eventually, but that is quite some time away; it's not even on the roadmap as of today. Plus, convs are hard :-). I generally trust nvFuser to get the best perf for memory-bound workloads today, but the situation is spotty for compute-bound workloads, and convs are generally compute-bound.

I think we should pursue other approaches, e.g., as you mention, seeing whether PyTorch has a single kernel here, or maybe we can just call cuDNN directly. cc @vedaanta

tfogal · Jul 18 '24 15:07

Note that this is more for investigation and knowing what's going on than any settled conclusion. Comparing to fp64 in this trivial example, it seems that we're not worse than eager autocast in accuracy, so it might just be different.

t-vi · Jul 18 '24 15:07

triage review:

  • we are doing strange things in backward, e.g. using forward kernels
  • we'll probably hit this in the proxy model once we get things working
  • eager actually isn't doing a great job here; this is an opportunity, not a fault we have today
  • unknown how impactful this could be at present: somebody needs to dig in

tfogal · Jul 22 '24 15:07