conv backward in thunder
While testing #797, it seems that Thunder's backward for conv might not be optimal (accuracy? speed?):
```python
import torch
import thunder


def foo(x, w, b=None):
    return torch.nn.functional.conv2d(x, w, b)


x = torch.randn(1, 2, 8, 8, requires_grad=True)
w = torch.randn(3, 2, 4, 4, requires_grad=True)
b = torch.randn(3, requires_grad=True)
go = torch.randn(1, 3, 5, 5)

jfoo = thunder.jit(foo)

# float64 reference for the accuracy comparison
x64 = x.to(torch.float64)
w64 = w.to(torch.float64)
b64 = b.to(torch.float64)
ref_eager_out = foo(x64, w64, b64)
ref_eager_grads = torch.autograd.grad(ref_eager_out, [x64, w64, b64], go.to(torch.float64))

with torch.autocast("cpu", torch.float16):
    print("eager")
    with torch.profiler.profile() as prof:
        eager_out = foo(x, w, b)
        eager_grads = torch.autograd.grad(eager_out, [x, w, b], go)
    print(prof.key_averages().table())

    print("thunder")
    with torch.profiler.profile() as prof:
        jit_out = jfoo(x, w, b)
        jit_grads = torch.autograd.grad(jit_out, [x, w, b], go)
    print(prof.key_averages().table())

torch.testing.assert_close(eager_out, jit_out)
for eg, jg, rg in zip(eager_grads, jit_grads, ref_eager_grads):
    # TODO: tighten check?
    print(f"ref - eager {(eg - rg).abs().max().item():.4f} ref - thunder {(jg - rg).abs().max().item():.4f}")
    torch.testing.assert_close(eg, jg, atol=1e-2, rtol=1e-2)
```
This gives the output below (note that Thunder's backward seems to run the conv2d forward twice more):
```
eager
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
aten::conv2d 4.64% 16.030us 99.37% 343.382us 171.691us 2
aten::to 2.59% 8.950us 20.87% 72.130us 10.304us 7
aten::_to_copy 10.49% 36.240us 18.28% 63.180us 9.026us 7
aten::empty_strided 4.68% 16.190us 4.68% 16.190us 2.313us 7
aten::copy_ 6.26% 21.620us 6.26% 21.620us 2.702us 8
aten::convolution 4.14% 14.290us 40.00% 138.241us 138.241us 1
aten::_convolution 3.99% 13.790us 35.87% 123.951us 123.951us 1
aten::_nnpack_available 0.25% 0.850us 0.25% 0.850us 0.425us 2
aten::thnn_conv2d 0.74% 2.570us 31.70% 109.561us 109.561us 1
aten::_slow_conv2d_forward 24.09% 83.260us 30.96% 106.991us 106.991us 1
aten::empty 2.14% 7.390us 2.14% 7.390us 1.232us 6
aten::view 2.17% 7.491us 2.17% 7.491us 1.873us 4
aten::resize_ 1.17% 4.060us 1.17% 4.060us 1.353us 3
aten::reshape 0.41% 1.400us 0.62% 2.150us 2.150us 1
autograd::engine::evaluate_function: ConvolutionBack... 1.84% 6.350us 26.54% 91.701us 91.701us 1
ConvolutionBackward0 1.86% 6.430us 24.70% 85.351us 85.351us 1
aten::convolution_backward 3.16% 10.911us 22.84% 78.921us 78.921us 1
aten::_slow_conv2d_backward 10.20% 35.240us 19.61% 67.760us 67.760us 1
aten::resize_as_ 0.46% 1.600us 0.86% 2.980us 2.980us 1
aten::zero_ 0.40% 1.390us 0.40% 1.390us 0.695us 2
aten::sum 5.02% 17.360us 6.38% 22.050us 22.050us 1
aten::as_strided 0.52% 1.790us 0.52% 1.790us 1.790us 1
aten::fill_ 0.84% 2.900us 0.84% 2.900us 2.900us 1
autograd::engine::evaluate_function: ToCopyBackward0... 1.85% 6.410us 7.86% 27.150us 9.050us 3
ToCopyBackward0 1.22% 4.230us 6.00% 20.740us 6.913us 3
autograd::engine::evaluate_function: torch::autograd... 0.54% 1.870us 0.54% 1.870us 0.623us 3
cudaDeviceSynchronize 4.33% 14.960us 4.33% 14.960us 14.960us 1
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 345.572us
thunder
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
aten::to 0.67% 4.030us 4.20% 25.310us 3.616us 7
aten::_to_copy 1.67% 10.060us 3.53% 21.280us 3.040us 7
aten::empty_strided 1.63% 9.790us 1.63% 9.790us 0.890us 11
aten::copy_ 2.20% 13.260us 2.20% 13.260us 1.105us 12
aten::convolution 0.73% 4.390us 65.19% 392.634us 130.878us 3
aten::_convolution 1.44% 8.660us 64.46% 388.244us 129.415us 3
aten::_nnpack_available 0.04% 0.260us 0.04% 0.260us 0.087us 3
aten::thnn_conv2d 0.29% 1.730us 62.19% 374.564us 124.855us 3
aten::_slow_conv2d_forward 59.96% 361.134us 61.90% 372.834us 124.278us 3
aten::empty 0.69% 4.150us 0.69% 4.150us 0.593us 7
aten::view 0.84% 5.050us 0.84% 5.050us 0.842us 6
aten::resize_ 0.34% 2.020us 0.34% 2.020us 0.673us 3
aten::reshape 1.22% 7.320us 2.41% 14.530us 2.906us 5
ThunderFunction 2.41% 14.510us 2.41% 14.510us 14.510us 1
autograd::engine::evaluate_function: ThunderFunction... 0.64% 3.870us 34.79% 209.572us 209.572us 1
ThunderFunctionBackward 17.19% 103.561us 34.15% 205.702us 205.702us 1
aten::permute 1.46% 8.770us 1.99% 11.980us 1.997us 6
aten::as_strided 0.64% 3.880us 0.64% 3.880us 0.485us 8
aten::sum 1.31% 7.920us 1.57% 9.480us 9.480us 1
aten::fill_ 0.21% 1.240us 0.21% 1.240us 1.240us 1
aten::pad 0.59% 3.570us 2.54% 15.310us 5.103us 3
aten::constant_pad_nd 0.54% 3.230us 1.95% 11.740us 3.913us 3
aten::clone 0.62% 3.740us 2.08% 12.540us 3.135us 4
aten::_reshape_alias 0.63% 3.790us 0.63% 3.790us 1.895us 2
aten::flip 0.97% 5.830us 1.31% 7.890us 7.890us 1
aten::empty_like 0.26% 1.550us 0.50% 3.020us 1.510us 2
aten::contiguous 0.12% 0.730us 0.79% 4.760us 4.760us 1
autograd::engine::evaluate_function: torch::autograd... 0.16% 0.960us 0.16% 0.960us 0.320us 3
cudaDeviceSynchronize 0.55% 3.330us 0.55% 3.330us 3.330us 1
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 602.335us
```
Edit: I don't have the output for the accuracy check that I edited into the script above, but my impression is that the accuracy of the Thunder backward is not worse than the eager one in this example, and the example is not terribly relevant for perf. It's just that we should develop insight into what's going on, because we will bump into the question eventually.
I wonder if #655 is related; maybe the same method could provide a zeroth-order analysis of what's going on.
cc @tfogal
The backward of conv is conv again, both for the input grad and for the weight grad. But let's indeed investigate the performance of the decomposition. Maybe, again, we can replace it with the PyTorch kernel. Maybe some things would be better upcast inside the decomposition (for example, the reduction for the bias grad), unless such things are taken care of by NVFuser.
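For intuition, here is a minimal sketch of my own (restricted to stride 1, no padding, no dilation, no groups) showing that both the input grad and the weight grad of conv2d are convolutions again, checked against autograd:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 2, 8, 8, dtype=torch.float64, requires_grad=True)
w = torch.randn(3, 2, 4, 4, dtype=torch.float64, requires_grad=True)
b = torch.randn(3, dtype=torch.float64, requires_grad=True)
go = torch.randn(1, 3, 5, 5, dtype=torch.float64)

out = F.conv2d(x, w, b)
gx, gw, gb = torch.autograd.grad(out, [x, w, b], go)

# grad wrt input: a transposed convolution of grad_output with the same weight
gx_conv = F.conv_transpose2d(go, w)
# grad wrt weight: a convolution of the input with grad_output, with the
# batch and channel dimensions swapped on both operands and on the result
gw_conv = F.conv2d(x.transpose(0, 1), go.transpose(0, 1)).transpose(0, 1)
# grad wrt bias: a reduction over batch and spatial dimensions
gb_sum = go.sum(dim=(0, 2, 3))

torch.testing.assert_close(gx, gx_conv)
torch.testing.assert_close(gw, gw_conv)
torch.testing.assert_close(gb, gb_sum)
```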
The decomposition at work is also what we see with #655 in fp6. But I think we want to look into whether the decomposition has perf or numerical-accuracy impacts.
The grad tests are solid and super comprehensive, but they are done in high-precision modes. When run in lower-precision modes, we might lose precision between the calls to the convolutions (I expect PyTorch to behave well there in forward, but I am not sure), unless NVFuser handles these things with grace, and I am not 100% sure that it does.
I wonder whether conv backward in PyTorch updates the grads in a single kernel? The best perf improvement I see is for NVFuser to get its own native convolution support; best in the sense of not requiring changes to Thunder. Currently NVFuser will not claim conv, so the convolutions in the backward decomposition will end up as two separate kernel runs.
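On the single-kernel question: in eager, autograd routes through a single aten::convolution_backward op that returns all three grads from one call (which backend kernel it dispatches to, e.g. _slow_conv2d_backward in the eager profile above, depends on device, dtype and build). A hedged sketch of exercising that op directly; the argument layout follows the op schema in current PyTorch and may differ across versions:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 2, 8, 8, requires_grad=True)
w = torch.randn(3, 2, 4, 4, requires_grad=True)
b = torch.randn(3, requires_grad=True)
go = torch.randn(1, 3, 5, 5)

out = F.conv2d(x, w, b)
gx, gw, gb = torch.autograd.grad(out, [x, w, b], go)

# The same three grads from one ATen call; the actual kernel it lowers to
# (slow_conv2d, cudnn, mkldnn, ...) is a backend decision.
with torch.no_grad():
    gx2, gw2, gb2 = torch.ops.aten.convolution_backward(
        go, x, w,
        [b.shape[0]],            # bias_sizes
        [1, 1], [0, 0], [1, 1],  # stride, padding, dilation
        False, [0, 0], 1,        # transposed, output_padding, groups
        [True, True, True],      # output_mask: grad_input, grad_weight, grad_bias
    )

torch.testing.assert_close(gx, gx2)
torch.testing.assert_close(gw, gw2)
torch.testing.assert_close(gb, gb2)
```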
cc @tfogal
Thanks for tagging me. I am sure we will want convs in nvFuser eventually, but it'll be quite some time; it's not even on the roadmap as of today. Plus, convs are hard :-). I generally trust nvFuser to get the best perf for memory-bound workloads today but the situation is spotty for compute-bound workloads and convs are generally compute-bound.
I think we should pursue other approaches, e.g., as you mention, seeing whether PyTorch has a single kernel here, or maybe we can just call cuDNN directly. cc @vedaanta
Note that this is more about investigating and understanding what's going on than any set conclusion. Comparing against fp64 in this trivial example, it seems that we're not worse than eager autocast in accuracy, so it might just be different.
triage review:
- we are doing strange things in backward, e.g. using forward kernels
- we'll probably hit this in the proxy model once we get things working
- eager actually isn't doing a great job here; this is an opportunity, not a fault we have today.
- unknown how impactful this is at present: somebody needs to dig in (e.g., by inspecting the backward trace, as sketched below)
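One way to start digging in is to look at the backward trace Thunder actually executes and count the convolution-like ops in it. A sketch, assuming the thunder.last_backward_traces inspection API is available and that the last trace in the returned list is the executed one:

```python
import torch
import thunder


def foo(x, w, b=None):
    return torch.nn.functional.conv2d(x, w, b)


x = torch.randn(1, 2, 8, 8, requires_grad=True)
w = torch.randn(3, 2, 4, 4, requires_grad=True)
b = torch.randn(3, requires_grad=True)

jfoo = thunder.jit(foo)
out = jfoo(x, w, b)
out.sum().backward()

# Print the executed backward trace and grep it for "conv" to see which
# (forward) convolution calls the backward decomposition ends up using.
bw_trace = thunder.last_backward_traces(jfoo)[-1]
print(bw_trace)
print([line for line in str(bw_trace).split("\n") if "conv" in line])
```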