25% Performance regression from v0.1.1 to 0.2.0 when calculating hessian
Hi developers,
After I upgraded functorch from v0.1.1 to v0.2.0, I noticed a 25% performance regression when calculating the Hessian. Please check the benchmark results below and the attached benchmark script.
Please let me know if I did anything wrong, and also whether the performance regression can be fixed. Thanks!
Benchmark result
Benchmark result on NVIDIA A100
# torch 1.11 and functorch 0.1.1
===== benchmark without backward =====
max pred error: functorch: 0.00e+00
max hessian error: functorch: 0.00e+00
reference_hessian: 61.837 ms
functorch_hessian: 29.474 ms
# torch 1.12 and functorch 0.2.0
===== benchmark without backward =====
max pred error: functorch: 1.49e-08
max hessian error: functorch: 0.00e+00
reference_hessian: 62.519 ms
functorch_hessian: 39.666 ms (0.75 X)
Benchmark result on NVIDIA A6000
# torch 1.11 and functorch 0.1.1
===== benchmark without backward =====
max pred error: functorch: 1.49e-08
max hessian error: functorch: 0.00e+00
reference_hessian: 65.984 ms
functorch_hessian: 33.662 ms
# torch 1.12 and functorch 0.2.0
===== benchmark without backward =====
max pred error: functorch: 1.86e-08
max hessian error: functorch: 0.00e+00
reference_hessian: 67.285 ms
functorch_hessian: 49.723 ms (0.68 X)
Benchmark script
benchmark.py
import time
import argparse

from functorch import vmap, jacrev, jacfwd
import torch
import torch.nn as nn

torch.backends.cuda.matmul.allow_tf32 = False
_ = torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

D1 = 2  # x, y
D2 = 3  # u, v, p
B = 10000
x = torch.randn(B, D1).to(device)
run_backward = False

model = nn.Sequential(
    nn.Linear(D1, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, D2),
).to(device)


def predict(x):
    torch.cuda.nvtx.range_push("forward")
    out = model(x)
    torch.cuda.nvtx.range_pop()
    return out, out  # returning two outputs is needed for jacrev's auxiliary object


def reference_hessian():
    x_ = x.clone().requires_grad_()
    ones = torch.ones(B, device=x.device)
    pred, _ = predict(x_)
    jacobian_rows = [None] * D2
    hessian_rows = [None] * (D2 * D1)
    for i in range(D2):
        torch.cuda.nvtx.range_push("autograd jacobian")
        jacobian_rows[i] = torch.autograd.grad(pred[:, i], x_, ones, create_graph=True)[0]
        torch.cuda.nvtx.range_pop()
    for i in range(D2):
        for j in range(D1):
            torch.cuda.nvtx.range_push("autograd hessian")
            hessian_rows[i * D1 + j] = torch.autograd.grad(
                jacobian_rows[i][:, j], x_, ones, create_graph=True
            )[0]
            torch.cuda.nvtx.range_pop()
    jacobian = torch.stack(jacobian_rows)  # [D2, B, D1]
    hessian = torch.stack(hessian_rows)  # [D2 * D1, B, D1]
    if run_backward:
        l = hessian.sum()
        l.backward()
    return hessian.transpose(0, 1), pred


def functorch_hessian():
    x_ = x.clone().requires_grad_()
    hessian, pred = vmap(
        jacfwd(jacrev(predict, argnums=0, has_aux=True), argnums=0, has_aux=True),
        in_dims=0,
    )(x_)  # [B, D2, D1, D1]
    if run_backward:
        l = hessian.sum()
        l.backward()
    return hessian, pred


def validate_result():
    # test functorch result against the autograd reference
    ref_hes, ref_pred = reference_hessian()
    ft_hes, ft_pred = functorch_hessian()
    ref_hes = ref_hes.view_as(ft_hes)
    print(f"max pred error: functorch: {(ref_pred - ft_pred).max():.2e}")
    print(f"max hessian error: functorch: {(ref_hes - ft_hes).max():.2e}")


def benchmark(func):
    N = 20
    torch.cuda.synchronize()
    start = time.time()
    for i in range(N):
        torch.cuda.nvtx.range_push(func.__name__)
        _ = func()
        torch.cuda.nvtx.range_pop()
    torch.cuda.synchronize()
    time_ms = ((time.time() - start) / N) * 1000
    print(f"{func.__name__}: {time_ms:.3f} ms")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-b", "--backward", default=False, action="store_true")
    args = parser.parse_args()
    if args.backward:
        run_backward = True
        print("===== benchmark with backward =====")
    else:
        print("===== benchmark without backward =====")

    validate_result()
    # warm up
    for i in range(10):
        reference_hessian()
        functorch_hessian()
    # benchmark hessian
    benchmark(reference_hessian)
    benchmark(functorch_hessian)
ping @samdow @zou3519
Thanks for the report, we'll take a look soon
Bisected to https://github.com/pytorch/pytorch/pull/75195/files. That PR by itself may not be a problem; perhaps the problem is our batching rule for mv.
@yueyericardo is the repro you provided the entire model, or is it a subset of some model that you're running?
cc @lezcano @ezyang for https://github.com/pytorch/pytorch/pull/75195 -- this led to a performance regression in functorch. I'm not sure what the original intent of the PR was (there are no tests). I'm still trying to root-cause this, but it is a bit difficult to visualize. What are the chances we could revert that PR?
(I've confirmed that reverting that single PR on pytorch/pytorch master makes the performance regression go away)
Many thanks to @zou3519 for the quick debugging! I'm working on the NVIDIA Modulus project; we use functorch because it provides large speedups for the Jacobian and Hessian calculations. The minimal repro I provided is only a subset of our model, just enough to demonstrate the performance regression.
Thanks again!
@yueyericardo - for the original model itself, is the performance regression also 25%, or is it a smaller number? Is the original model public? One thing we can do to prevent future regressions is to check the original model into https://github.com/pytorch/benchmark.
I've noticed a lot of other similar models where folks have an nn.Sequential that is just made up of nn.Linear and activations and need to compute a vmap(jacrev(...)) or vmap(hessian(...)) of the quantity, so we could also potentially just check your script into torchbench otherwise.
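For context, here is a minimal sketch of that usage pattern (the toy model, sizes, and the functorch.hessian call are illustrative assumptions on my part, not taken from Modulus):
import torch
import torch.nn as nn
from functorch import vmap, hessian

net = nn.Sequential(nn.Linear(2, 64), nn.Tanh(), nn.Linear(64, 1))
xs = torch.randn(128, 2)  # a batch of inputs

def scalar_out(x):
    # per-sample scalar output of the small MLP
    return net(x).squeeze(-1)

# per-sample Hessian of the output w.r.t. the input, shape [128, 2, 2]
per_sample_hessian = vmap(hessian(scalar_out))(xs)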
Hi @zou3519, our source code is free to download from the website, but it is not developed on GitHub, and our code base might also be too large to put into pytorch/benchmark.
Yes, exactly! I believe the minimal repro I provided is enough to prevent future regressions for our model. Thanks!
I think that rather than blindly reverting, we should get to the root of the problem, as it is very weird to get such a regression when dispatching from a more general function to a more concrete one (that was the reason for that PR).
Things that come to mind are:
- Is this regression architecture- / cuBLAS-version-dependent?
- Is this regression also happening on regular matmul for that codepath?
If the answer to the above two is no, then this performance issue is likely on the functorch end and should be fixed. Otherwise, it's on the cuBLAS end and should be reported to NVIDIA (a sketch for checking the second question follows below).
cc @ngimel @xwang233 @ivanyashchuk
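A rough sketch (my own assumption of how one might check the second question, not code from the thread) comparing, in eager mode, the mv path that matmul takes for a 1-d first argument against the old unsqueeze + mm path:
import torch
import torch.utils.benchmark as benchmark

device = "cuda" if torch.cuda.is_available() else "cpu"
mat = torch.randn(512, 512, device=device)
vec = torch.randn(512, device=device)

# new path: matmul dispatches the vector-matrix case to mv
t_mv = benchmark.Timer(stmt="torch.matmul(vec, mat)",
                       globals={"torch": torch, "vec": vec, "mat": mat})
# old path: unsqueeze to 2d, mm, squeeze back
t_mm = benchmark.Timer(stmt="vec.unsqueeze(0).mm(mat).squeeze(0)",
                       globals={"vec": vec, "mat": mat})
print(t_mv.timeit(1000))
print(t_mm.timeit(1000))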
@Lezcano It's fair to submit upstream bugs, but if we know that the upstream library's general kernel has better perf than a specialized one, we might as well use it.
@Lezcano @ezyang let's say that we did revert the PR (because we're trying to release PyTorch 1.12.1 as soon as possible). Would it cause any other problems?
Because the motivation was "dispatching from a more general function to a more concrete [function]", it sounds like this wouldn't change very much else.
No, I don't think so. The PR is supposed to make the kernel run faster.
fwiw, I think this may be related to the open PR I have to avoid copies in matmul. Could you check whether https://github.com/pytorch/pytorch/pull/76828 fixes this?
In any case, I'm fine with reverting, but we should investigate what's causing this regardless
@Lezcano @ezyang let's say that we did revert the PR (because we're trying to release PyTorch 1.12.1 as soon as possible). Would it cause any other problems?
Because the motivation was "dispatching from a more general function to a more concrete [function]", it sounds like this wouldn't change very much else.
That PR is fixing spurious resize warnings that were previously generated, and by itself is supposed to speed things up by avoiding squeeze/unsqueeze calls, which are not free (and especially not free when autograd is needed). As for more general vs. more concrete function performance, we should investigate this, but I doubt that's the case.
@zou3519 can you by any chance collect profiling results for old and new versions?
In any case, I'm fine with reverting, but we should investigate what's causing this regardless
I agree this warrants more investigation. We've got a problem in that there is a timeline for 1.12.1, and I am not sure how long it is going to take to actually get to the bottom of this.
@zou3519 can you by any chance collect profiling results for old and new versions?
I can try but it's been a long time since I touched nvprof or nsight profiler, so I will need to relearn the magic invocations.
but we should investigate what's causing this regardless
Since we changed the mm to mv, functorch generates different code for vmap(jacrev(jacfwd(mm))) as opposed to vmap(jacrev(jacfwd(mv))). It's plausible that the problem is that "functorch should generate better code"; we're still digging into it.
Isn't 1.12.1 done already?
No nsight needed, the torch profiler should be enough (wrap the call in with torch.profiler.profile() as p: and print the key averages and export a chrome trace at the end).
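For reference, a minimal sketch of that suggestion (the exact options and file name are my own assumptions), wrapped around the functorch_hessian function from the benchmark script above:
import torch
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as p:
    functorch_hessian()  # from the benchmark script above
print(p.key_averages().table(sort_by="cuda_time_total", row_limit=15))
p.export_chrome_trace("functorch_hessian_trace.json")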
FYI, here are the nsight profiling results. Before:
Time (%)   Total Time   Instances   Avg   Med   Min   Max   StdDev   Name
64.7% 928.624 ms 1000 928.623 μs 867.621 μs 297.345 μs 1.686 ms 493.406 μs ampere_sgemm_128x64_nn
10.3% 147.304 ms 250 589.217 μs 589.187 μs 588.643 μs 590.435 μs 275 ns ampere_sgemm_128x128_tn
9.4% 134.966 ms 1300 103.819 μs 70.432 μs 5.344 μs 194.209 μs 75.470 μs void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::BinaryFunctor<float, float, float, at::native::AddFunctor<float>>>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)
6.8% 98.146 ms 900 109.050 μs 94.848 μs 70.081 μs 177.921 μs 43.651 μs void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::BinaryFunctor<float, float, float, void at::native::threshold_kernel_impl<float>(at::TensorIteratorBase &, T1, T1)::[lambda(float, float) (instance 1)]>>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)
5.4% 77.550 ms 350 221.572 μs 300.321 μs 18.080 μs 302.050 μs 124.914 μs ampere_sgemm_128x64_tn
0.8% 10.981 ms 1450 7.572 μs 3.552 μs 2.880 μs 26.112 μs 7.973 μs void at::native::vectorized_elementwise_kernel<(int)4, at::native::FillFunctor<float>, at::detail::Array<char *, (int)1>>(int, T2, T3)
0.7% 9.949 ms 150 66.324 μs 82.928 μs 27.968 μs 88.672 μs 26.863 μs ampere_sgemm_32x32_sliced1x4_nn
0.6% 8.074 ms 300 26.912 μs 26.976 μs 25.696 μs 27.776 μs 390 ns void at::native::vectorized_elementwise_kernel<(int)4, at::native::<unnamed>::launch_clamp_scalar(at::TensorIteratorBase &, c10::Scalar, c10::Scalar, at::native::detail::ClampLimits)::[lambda() (instance 1)]::operator ()() const::[lambda() (instance 8)]::operator ()() const::[lambda(float) (instance 1)], at::detail::Array<char *, (int)2>>(int, T2, T3)
0.5% 7.713 ms 50 154.257 μs 154.177 μs 153.184 μs 155.905 μs 634 ns ampere_sgemm_32x128_nn
0.3% 4.847 ms 100 48.468 μs 47.664 μs 34.688 μs 63.393 μs 13.397 μs ampere_sgemm_32x32_sliced1x4_tn
0.2% 3.342 ms 650 5.141 μs 5.344 μs 4.256 μs 5.728 μs 423 ns void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::<unnamed>::direct_copy_kernel_cuda(at::TensorIteratorBase &)::[lambda() (instance 2)]::operator ()() const::[lambda() (instance 8)]::operator ()() const::[lambda(float) (instance 1)]>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)
0.1% 2.028 ms 50 40.569 μs 40.512 μs 40.032 μs 41.536 μs 272 ns ampere_sgemm_64x64_nn
0.1% 994.305 μs 50 19.886 μs 19.856 μs 19.649 μs 20.896 μs 169 ns ampere_sgemm_128x32_nn
0.0% 435.235 μs 100 4.352 μs 4.384 μs 4.160 μs 4.576 μs 100 ns void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::FillFunctor<float>>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)
After (note the third row: 18.4% of the time in a copy kernel):
Time (%)   Total Time   Instances   Avg   Med   Min   Max   StdDev   Name
21.8% 433.409 ms 500 866.818 μs 866.756 μs 865.028 μs 870.116 μs 881 ns ampere_sgemm_32x128_nt
21.2% 420.235 ms 250 1.681 ms 1.681 ms 1.678 ms 1.685 ms 813 ns ampere_sgemm_128x64_nn
18.4% 365.797 ms 2850 128.349 μs 83.809 μs 4.096 μs 608.931 μs 173.023 μs void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::<unnamed>::direct_copy_kernel_cuda(at::TensorIteratorBase &)::[lambda() (instance 2)]::operator ()() const::[lambda() (instance 14)]::operator ()() const::[lambda(float) (instance 1)]>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)
10.8% 213.750 ms 300 712.501 μs 588.802 μs 588.195 μs 1.335 ms 277.174 μs ampere_sgemm_128x128_tn
8.1% 160.586 ms 1300 123.527 μs 93.057 μs 5.089 μs 266.497 μs 92.181 μs void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::CUDAFunctor_add<float>>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)
7.9% 156.152 ms 500 312.304 μs 312.290 μs 310.145 μs 314.689 μs 781 ns ampere_sgemm_128x32_nn
5.9% 117.864 ms 900 130.960 μs 129.824 μs 77.504 μs 203.681 μs 49.220 μs void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::BinaryFunctor<float, float, float, void at::native::<unnamed>::threshold_kernel_impl<float>(at::TensorIteratorBase &, T1, T1)::[lambda(float, float) (instance 1)]>>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)
3.3% 66.223 ms 50 1.324 ms 1.324 ms 1.324 ms 1.326 ms 397 ns ampere_sgemm_128x128_tt
1.1% 21.094 ms 1850 11.402 μs 3.584 μs 2.816 μs 65.857 μs 12.662 μs void at::native::vectorized_elementwise_kernel<(int)4, at::native::FillFunctor<float>, at::detail::Array<char *, (int)1>>(int, T2, T3)
0.4% 8.071 ms 300 26.904 μs 26.976 μs 25.792 μs 28.128 μs 392 ns void at::native::vectorized_elementwise_kernel<(int)4, at::native::<unnamed>::launch_clamp_scalar(at::TensorIteratorBase &, c10::Scalar, c10::Scalar, at::native::detail::ClampLimits)::[lambda() (instance 1)]::operator ()() const::[lambda() (instance 14)]::operator ()() const::[lambda(float) (instance 1)], at::detail::Array<char *, (int)2>>(int, T2, T3)
0.4% 7.903 ms 50 158.063 μs 158.049 μs 156.609 μs 159.617 μs 634 ns ampere_sgemm_32x128_nn
0.4% 7.359 ms 100 73.594 μs 73.600 μs 72.000 μs 78.368 μs 1.223 μs ampere_sgemm_32x32_sliced1x4_nt
0.2% 3.242 ms 50 64.844 μs 64.928 μs 62.464 μs 65.761 μs 528 ns ampere_sgemm_32x32_sliced1x4_tn
0.1% 2.255 ms 100 22.551 μs 22.688 μs 20.865 μs 24.192 μs 1.106 μs void gemmSN_NN_kernel<float, (int)256, (int)4, (int)2, (int)8, (int)3, (int)4, (bool)0, cublasGemvTensorStridedBatched<const float>, cublasGemvTensorStridedBatched<const float>, cublasGemvTensorStridedBatched<float>>(cublasGemmSmallNParams<T9, T10, T11, T1>)
0.1% 2.043 ms 100 20.427 μs 20.416 μs 20.160 μs 22.656 μs 257 ns ampere_sgemm_128x32_tn
0.0% 432.292 μs 100 4.322 μs 4.336 μs 4.128 μs 4.640 μs 97 ns void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::FillFunctor<float>>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)
@ngimel the PR that fixed the warnings was already merged; this one is just concerned with avoiding copies. One of the cases where it elides a copy is when you multiply a matrix by a batch of matrices. This is exactly the batched version of the vector-matrix product that's causing the regression. That's why I think it may fix it.
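For illustration, a small sketch of the shape pattern being referred to, a single matrix multiplied with a batch of matrices (the shapes are my own assumptions chosen to mirror the repro, not taken from the PR):
import torch

A = torch.randn(6, 512)          # a single matrix (e.g. D2 * D1 rows)
B = torch.randn(10000, 512, 2)   # a batch of matrices
out = torch.matmul(A, B)         # A is broadcast across the batch dimension
print(out.shape)                 # torch.Size([10000, 6, 2])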
That must be coming from functorch, as that PR doesn't introduce any additional copies.
yup, I think https://github.com/pytorch/pytorch/pull/76828 would fix the regression
I patched in https://github.com/pytorch/pytorch/pull/76828 and the above script ends up OOM-ing :/
Isn't 1.12.1 done already?
Not yet, we have a chance to change it (or request a 1.12.2 if necessary since this regression is large and these types of models are typical functorch usage)
Re. OOM. Wow, that's certainly unexpected. I'm not sure what's the best way to follow up on that. Probably @ngimel has a better idea how to proceed.
Regardless, landing that PR (the avoid-copies matmul one) will be tricky due to some discrepancies in the accuracy of mm vs mv that we found for float16. As such, I think the most realistic way forward would be to revert the offending PR and investigate what's causing that OOM after 1.12.1 is released.
the PR that fixed the warnings was already merged
I don't think that's true; #75195 itself is fixing a warning (otherwise a user-supplied, correct 1d out was sent to mm, and mm complained about resizing it to 2d).
I think the one that fixed the out version was https://github.com/pytorch/pytorch/pull/75197 and a previous one in that stack, but it may be the case that #75195 also fixed an out warning; I don't remember now.
Here is my attempt at a smaller repro that runs on a single version of pytorch. It computes a different quantity, vmap(jacrev(smaller_predict)), but also exhibits the performance difference (1.3 ms vs 7 ms on my machine).
https://gist.github.com/zou3519/3869d460f8bcb12799967e08a5998d9c
In the gist there's also a trace (acquired via make_fx) of what the graph looks like, using the "old linear" (aka the implementation of matmul in 1.11) vs the "new linear" (the implementation of matmul in 1.12), if anyone's feeling ambitious about reading traces
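For anyone who wants to reproduce such a trace themselves, here is a rough sketch (not the gist itself; the toy function and shapes are assumptions, and it assumes make_fx is importable from functorch in this version):
import torch
from functorch import make_fx, vmap, jacrev

W = torch.randn(512, 512)

def f(x):
    # toy stand-in for one Linear + ReLU layer; per-sample x is a vector,
    # so x @ W hits the vector-matrix codepath under discussion
    return torch.relu(x @ W)

xs = torch.randn(64, 512)
traced = make_fx(vmap(jacrev(f)))(xs)
print(traced.graph)  # shows which aten ops (mm/mv, copies, ...) end up in the graph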
@yueyericardo I got access to the modulus repo -- could you point me to which model in the modulus repo contains the nn.Sequential above please? We're figuring out how to check the above code into torchbench and I'm looking for some more context -- is there a name for computing the hessian for the nn.Sequential? Do you generally run .backward() after computing the hessian? What are some representative input sizes (are the sizes in the benchmark script representative?)
Hi @zou3519, thanks for following up! We are having some internal discussions regarding this and will come back to you tomorrow.
From looking at this a bit, I think what happened is:
- pytorch changes unsqueeze + mm -> mv in the case where the first element is a vector
- this example hits that codepath under vmap => functorch's mv batching rule
- the vector is batched and the matrix is not => functorch moves the vector's batch dim to the end and matmuls. The batch dim is now the 1st dimension, not the 0th
- relu's output is now BatchedTensor([512, 10000], lvl=0, dim=1). This needs to be saved as an activation for the backward pass
During the backward pass,
- grad_output is a BatchedTensor(BatchedTensor([10000, 512, 6], lvl=0, dim=0), lvl=1, dim=2)
- threshold_backward (relu's backward) takes in grad_output and relu (binary pointwise batch rule). It starts with the lvl 1 batch rule
- grad_output is the only value batched at the highest level, so we move its batch dim to the front with a movedim. grad_output is no longer contiguous
- redispatch on threshold_backward for lvl 0
- move relu's bdim to the front with a movedim and pad it to 3D. relu is no longer contiguous. grad_output does not need to move since its lvl 0 bdim is already 0; it is still contiguous
- the view on relu is not contiguous and the view on grad_output is not contiguous (illustrated just below) => threshold_backward triggers a copy. Locally I've tested cases where only one of the inputs is not contiguous and haven't seen the same slowdowns.
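As a tiny illustration (my own sketch, with shapes chosen to mirror the description above) of why those movedim calls leave non-contiguous tensors:
import torch

# grad_output analogue: [10000, 512, 6] with its (level-1) batch dim at position 2
g = torch.randn(10000, 512, 6)
print(g.movedim(2, 0).is_contiguous())   # False

# relu-activation analogue: [512, 10000] with its batch dim at position 1
r = torch.randn(512, 10000)
print(r.movedim(1, 0).is_contiguous())   # False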
We can also validate that this is our issue, since we hit pre-regression performance numbers by changing functorch's mv batch rule as follows:
auto other_ = moveBatchDimToFront(other, other_bdim);
auto self_ = at::movedim(self, 0, 1);
auto result = at::matmul(other_, self_);
return std::make_tuple(std::move(result), 0);
This doesn't trigger the copy, since the batch dimension of the saved relu activation is now the first dimension. However, this may hit other perf issues from transposing both self and result.
This is the smallest subset of @zou3519's trace where I can see the perf differences
import torch

mm_2 = torch.randn(60000, 512)
relu = torch.randn(10000, 512)

def old_linear_faster():
    mm_2_view = mm_2.view([10000, 6, 1, 512])
    mm_2_view_squeeze = mm_2_view.squeeze(2)
    relu_view = relu.view([10000, 1, 512])
    threshold_backward = torch.ops.aten.threshold_backward(mm_2_view_squeeze, relu_view, 0)
    return threshold_backward

bmm = torch.randn(10000, 512, 6)
other_relu = torch.randn(512, 10000)

def new_linear_faster():
    bmm_view = bmm.view([10000, 512, 6])
    bmm_view_permute = bmm_view.permute([0, 2, 1])
    other_relu_permute = other_relu.permute([1, 0])
    other_relu_permute_view = other_relu_permute.view([10000, 1, 512])
    threshold_backward = torch.ops.aten.threshold_backward(bmm_view_permute, other_relu_permute_view, 0)
    return threshold_backward
Notably, if I change the views on either bmm or other_relu to be contiguous constants, it has much faster performance. So it seems like threshold_backward doesn't copy if only one of its inputs is not contiguous, but does if both are not.
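A quick way (my own sketch, not part of the original comment) to time the two functions above and reproduce the gap:
import torch.utils.benchmark as benchmark

for fn in (old_linear_faster, new_linear_faster):
    t = benchmark.Timer(stmt="fn()", globals={"fn": fn})
    print(fn.__name__, t.timeit(100))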
Thanks @samdow, the copy from relu definitely seems to affect perf; however, there's also another copy coming from MvBackward.
Also, why does threshold_backward on discontiguous inputs trigger a copy? In eager mode, threshold_backward should be able to handle them via TensorIterator.