No batching rules for aten::_linalg_solve_ex, aten::linalg_solve, aten::linalg_solve_ex, and aten::_linalg_slogdet cause a significant slowdown for per-sample gradients with torch.linalg.slogdet
TL;DR - torch.linalg.slogdet is over an order of magnitude slower at computing per-sample gradients in the latest nightly PyTorch/FuncTorch (1.13.0.dev20220721 / 0.3.0a0+e8a68f4) than in a previous PyTorch/FuncTorch (1.12.0a0+git7c2103a / 0.2.0a0+9d6ee76) compiled from source. This seems to be due to the lack of batching rules for aten::_linalg_solve_ex, aten::linalg_solve, aten::linalg_solve_ex, and aten::_linalg_slogdet.
Thanks! :)
Hi All,
I've recently noticed that my code slowed down significantly (by around an order of magnitude) when moving from PyTorch 1.12 to 1.13. I've made a minimal reproducible example to highlight this issue. For reference, this started from #979, which has some more info, although that issue has been resolved and this new issue was opened as per @vfdev-5's suggestion.
The MRE below computes per-sample gradients, with respect to the parameters, of the Laplacian of a model w.r.t. its inputs. The script computes the per-sample gradients for N inputs from 1 to 6 and reports the walltime, and then uses torch.profiler.profile to give a clearer benchmark for N=4.
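For reference, the per-sample-gradient pattern used throughout is the usual vmap-over-grad composition from functorch; a minimal sketch with a toy linear model (the names net, loss, and x here are illustrative, not taken from the MRE):
import torch
from functorch import make_functional, vmap, grad

# toy model just to illustrate the per-sample-gradient pattern
net = torch.nn.Linear(3, 1)
fnet, params = make_functional(net)

def loss(params, x):
    # scalar output per sample, as required by grad
    return fnet(params, x).squeeze(-1)

x = torch.randn(8, 3)  # batch of 8 samples
# gradient w.r.t. params, vectorised over the batch dimension of x
per_sample_grads = vmap(grad(loss, argnums=0), in_dims=(None, 0))(params, x)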
I've benchmarked two versions of PyTorch/FuncTorch. The first version was built from source (and can be found here). The only thing changed is the slogdet_backward formula, which you can find here. The full version info for this "old-source" build is:
PyTorch version: 1.12.0a0+git7c2103a
CUDA version: 11.6
FuncTorch version: 0.2.0a0+9d6ee76
The other version is the latest nightly (hereafter referred to as "nightly"). The full version info for this "nightly" build is:
PyTorch version: 1.13.0.dev20220721
CUDA version: 11.6
FuncTorch version: 0.3.0a0+e8a68f4
A comparison of walltime (in seconds) as N increases from 1 to 6 is as follows:
N | [old-source] | [nightly]
1 | 0.5719 | 2.4907 # first call is slower (warm-up?)
1 | 0.0133 | 2.0593
2 | 0.0870 | 2.4496
3 | 0.1153 | 2.9293
4 | 0.1129 | 3.3715
5 | 0.1576 | 3.8302
6 | 0.2059 | 4.2622
The torch.profiler.profile output for the N=4 case with the "old-source" version is shown below, sorted by cuda_time_total:
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::matmul 0.38% 417.000us 4.25% 4.694ms 82.351us 0.000us 0.00% 127.220ms 2.232ms 57
aten::mm 0.69% 765.000us 3.77% 4.169ms 64.138us 25.971ms 24.41% 117.336ms 1.805ms 65
aten::bmm 0.45% 493.000us 1.06% 1.168ms 43.259us 64.058ms 60.20% 87.447ms 3.239ms 27
autograd::engine::evaluate_function: MmBackward0 0.12% 131.000us 1.37% 1.513ms 168.111us 0.000us 0.00% 42.798ms 4.755ms 9
MmBackward0 0.04% 40.000us 1.22% 1.347ms 149.667us 0.000us 0.00% 42.630ms 4.737ms 9
volta_dgemm_64x64_nt 0.00% 0.000us 0.00% 0.000us 0.000us 41.116ms 38.64% 41.116ms 4.112ms 10
autograd::engine::evaluate_function: AddmmBackward0 0.09% 103.000us 1.71% 1.890ms 189.000us 0.000us 0.00% 21.334ms 2.133ms 10
volta_dgemm_128x64_nt 0.00% 0.000us 0.00% 0.000us 0.000us 20.883ms 19.62% 20.883ms 3.481ms 6
AddmmBackward0 0.04% 39.000us 1.13% 1.254ms 125.400us 0.000us 0.00% 19.551ms 1.955ms 10
volta_dgemm_64x64_tn 0.00% 0.000us 0.00% 0.000us 0.000us 14.590ms 13.71% 14.590ms 2.432ms 6
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 110.510ms
Self CUDA time total: 106.412ms
However, with the latest "nightly" version the MRE slows down significantly, and the torch.profiler.profile output is dominated by the following ops: aten::_linalg_solve_ex, aten::linalg_solve, aten::linalg_solve_ex, and aten::_linalg_slogdet.
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::_linalg_solve_ex 14.79% 1.125s 141.49% 10.763s 212.108us 0.000us 0.00% 2.966s 58.452us 50741
aten::linalg_solve 0.09% 7.174ms 117.89% 8.967s 560.456ms 0.000us 0.00% 2.284s 142.775ms 16
aten::linalg_solve_ex 0.00% 37.000us 75.04% 5.708s 475.648ms 0.000us 0.00% 1.513s 126.102ms 12
autograd::engine::evaluate_function: LinalgSolveExBa... 0.00% 122.000us 62.86% 4.781s 683.040ms 0.000us 0.00% 1.275s 182.126ms 7
LinalgSolveExBackward0 0.00% 76.000us 62.85% 4.781s 683.008ms 0.000us 0.00% 1.275s 182.124ms 7
aten::linalg_lu_solve 9.34% 710.739ms 35.82% 2.724s 55.427us 643.114ms 33.69% 831.927ms 16.926us 49152
aten::linalg_lu_factor_ex 7.28% 553.883ms 20.95% 1.593s 27.784us 661.250ms 34.64% 721.665ms 12.585us 57344
aten::_linalg_slogdet 4.59% 349.122ms 72.77% 5.535s 658.539us 0.000us 0.00% 677.273ms 80.580us 8405
void getf2_cta_32x32<double, double>(int, int, int, ... 0.00% 0.000us 0.00% 0.000us 0.000us 540.579ms 28.32% 540.579ms 9.427us 57344
void trsm_batch_left_lower_kernel<double>(cublasTrsm... 0.00% 0.000us 0.00% 0.000us 0.000us 277.594ms 14.54% 277.594ms 5.648us 49152
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 7.607s
Self CUDA time total: 1.909s
functorch also prompts me with a UserWarning that batching rules do not exist for aten::_linalg_solve_ex, aten::linalg_solve, aten::linalg_solve_ex, and aten::_linalg_slogdet, and that it defaults to a for-loop, which will affect performance.
~/pytorch_nightly/debug/per-sample-elocal.py:49: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::_linalg_slogdet. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /tmp/pip-req-build-hjjdrhz_/functorch/csrc/BatchedFallback.cpp:83.)
sgn, logabs = torch.linalg.slogdet(mat)
~/anaconda3/envs/pytorch_nightly/lib/python3.9/site-packages/torch/autograd/__init__.py:294: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::linalg_solve. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /tmp/pip-req-build-hjjdrhz_/functorch/csrc/BatchedFallback.cpp:83.)
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
~/anaconda3/envs/pytorch_nightly/lib/python3.9/site-packages/torch/autograd/__init__.py:294: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::_linalg_solve_ex. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /tmp/pip-req-build-hjjdrhz_/functorch/csrc/BatchedFallback.cpp:83.)
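For intuition, the fallback these warnings refer to is roughly equivalent to looping over the batch in Python and stacking the results; a sketch only (the real fallback lives in C++ inside functorch), using slogdet as the example:
import torch

def slogdet_batched_fallback(mats):
    # rough sketch of what the vmap fallback does when an op has no batching rule:
    # run the op once per sample in a Python-level loop and stack the results
    sgns, logabss = [], []
    for m in mats:
        s, l = torch.linalg.slogdet(m)
        sgns.append(s)
        logabss.append(l)
    return torch.stack(sgns), torch.stack(logabss)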
The full script to reproduce this behaviour can be found below.
import torch
import torch.nn as nn
from torch import Tensor
from torch.profiler import profile, record_function, ProfilerActivity
import functorch
from functorch import jacrev, jacfwd, hessian, make_functional, vmap, grad
import time
_ = torch.manual_seed(0)
torch.set_default_dtype(torch.float64)
#version info
print("PyTorch version: ", torch.__version__)
print("CUDA version: ", torch.version.cuda)
print("FuncTorch version: ", functorch.__version__)
#time with torch synchronization
def sync_time() -> float:
torch.cuda.synchronize()
return time.perf_counter()
class model(nn.Module):
def __init__(self, num_inputs, num_hidden):
super(model, self).__init__()
self.num_inputs=num_inputs
self.func = nn.Tanh()
self.fc1 = nn.Linear(2, num_hidden)
self.fc2 = nn.Linear(num_hidden, num_inputs)
def forward(self, x):
"""
Takes x in [B,A,1] and maps it to sign/logabsdet value in Tuple([B,], [B,])
"""
idx=len(x.shape) #creates args for repeat if vmap is used or not
rep=[1 for _ in range(idx)]
rep[-2] = self.num_inputs
g = x.mean(dim=(idx-2), keepdim=True).repeat(*rep)
f = torch.cat((x,g), dim=-1)
h = self.func(self.fc1(f))
mat = self.fc2(h)
sgn, logabs = torch.linalg.slogdet(mat)
return sgn, logabs
#=================================================================================================#
#Profile code for N=1 to 6
#=================================================================================================#
B=4096 #batch
N=2 #input nodes
H=128 #number of hidden nodes
device=torch.device("cuda")
for N in [1,1,2,3,4,5,6]:
net = model(N, H)
net = net.to(device)
x = torch.randn(B,N,1,device=device) #input data
fnet, params = make_functional(net)
def logabs(params, x):
_, logabs = fnet(params, x)
return logabs
def kinetic_functorch(params, X):
#do once, and re-use via has_aux?
calc_jacobian = jacrev(logabs, argnums=1)
#can only use jacrev for back-compatibility in PyTorch-1.12 for torch.linalg.slogdet
calc_hessian = jacrev(jacrev(logabs, argnums=1), argnums=1)
return -0.5*torch.sum(calc_hessian(params, X).squeeze(-3).squeeze(-1).diagonal(0,-2,-1) + calc_jacobian(params, X).squeeze(-1).pow(2), dim=-1)
#per-sample gradients for local energy w.r.t params via FuncTorch
t1=sync_time()
elocal_grad_ft = vmap(grad(kinetic_functorch, argnums=0), in_dims=(None, 0))(params, x)
t2=sync_time()
print("N: %2i | Walltime: %6.4f (s)" % (N, t2-t1))
#=================================================================================================#
#Profile code for N=4
#=================================================================================================#
N=4
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
net = model(N, H)
net = net.to(device)
x = torch.randn(B,N,1,device=device) #input data
fnet, params = make_functional(net)
def logabs(params, x):
_, logabs = fnet(params, x)
return logabs
def kinetic_functorch(params, X):
#do once, and re-use via has_aux?
calc_jacobian = jacrev(logabs, argnums=1)
#can only use jacrev for back-compatibility in PyTorch-1.12 for torch.linalg.slogdet
calc_hessian = jacrev(jacrev(logabs, argnums=1), argnums=1)
return -0.5*torch.sum(calc_hessian(params, X).squeeze(-3).squeeze(-1).diagonal(0,-2,-1) + calc_jacobian(params, X).squeeze(-1).pow(2), dim=-1)
#per-sample gradients for local energy w.r.t params via FuncTorch
t1=sync_time()
elocal_grad_ft = vmap(grad(kinetic_functorch, argnums=0), in_dims=(None, 0))(params, x)
t2=sync_time()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
Thanks in advance! :)
cc @samdow -- were these the ones you were planning on adding? I haven't checked if these fall into the low-hanging-fruit category
Yep! I have changes for all of these functions locally so let me pick them onto the new repo and put up some PRs
@samdow did https://github.com/pytorch/pytorch/pull/82177 cover all of these?
Not quite! Discussed a bit offline but @AlphaBetaGamma96 I'm hoping to get https://github.com/pytorch/pytorch/pull/82814 in soon and then I'll double check that we see speedups for your example (thanks for the repro!). Sorry for the delay--we ran into some AD related bugs because of adding these rules
No need to apologize for the delay @samdow, thanks for solving the batch rule! Fingers crossed it all works!
Hi @AlphaBetaGamma96! Just wanted to let you know that linalg_solve just landed in https://github.com/pytorch/pytorch/pull/82814. I tested locally that this example ran significantly faster after the fix than before it (exact numbers around the same order of magnitude as yours)
Thanks for the issue and thanks for your patience as we worked through some AD issues
Hi @samdow, thanks for fixing this issue! A bit of a silly question, but I remember reading somewhere that functorch is being merged directly into pytorch (if that's the correct phrase). So, would I have to just download the latest nightly of pytorch (and now just ignore functorch), or do I just update functorch to its latest version as well as pytorch?
Not a silly question, there's been a lot of change in the past couple of weeks. Yes the main development of functorch is being done in pytorch/pytorch. So if you're building from source, you'll want to build pytorch master, cd into the functorch directory, and then build functorch.
Or (this workflow may occasionally break for ~a day and I don't do this, so let me know if it doesn't work for you), you can download pytorch nightly and then build the newest version of functorch against that. Options for getting functorch this way are either (1) downloading pytorch, cd-ing into the functorch directory, and building that or (2) downloading the functorch repo and building that (we aren't developing here but have a read-only sync for the moment)
Currently, sadly, it doesn't just work to download pytorch nightly and get the newest functorch with it. We have people actively working on making this work and we can keep you updated
Let me know if any of that doesn't make sense!
Hi @samdow, so just to check, I need to download the latest pytorch nightly from https://pytorch.org/get-started/locally/ and then install functorch from source (from https://github.com/pytorch/functorch#installing-functorch-main), and that should be ok? (Assuming I've correctly understood what you've said)
EDIT: That seems to have worked
For completeness, I thought I'd share the results for the latest nightly version:
PyTorch version: 1.13.0.dev20220820
CUDA version: 11.6
FuncTorch version: 0.3.0a0+86a9049
N: 1 | Walltime: 0.4445 (s)
N: 1 | Walltime: 0.0107 (s)
N: 2 | Walltime: 0.0928 (s)
N: 3 | Walltime: 0.1296 (s)
N: 4 | Walltime: 0.1261 (s)
N: 5 | Walltime: 0.1634 (s)
N: 6 | Walltime: 0.2002 (s)
STAGE:2022-08-20 18:11:05 27507:27507 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2022-08-20 18:11:05 27507:27507 ActivityProfilerController.cpp:300] Completed Stage: Collection
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::bmm 0.52% 620.000us 2.01% 2.397ms 40.627us 88.138ms 73.46% 172.966ms 2.932ms 59
aten::matmul 0.55% 662.000us 2.93% 3.503ms 61.456us 0.000us 0.00% 158.351ms 2.778ms 57
aten::mm 0.39% 468.000us 1.95% 2.336ms 48.667us 12.092ms 10.08% 87.588ms 1.825ms 48
autograd::engine::evaluate_function: AddmmBackward0 0.07% 86.000us 1.27% 1.517ms 151.700us 0.000us 0.00% 38.730ms 3.873ms 10
volta_dgemm_64x64_nt 0.00% 0.000us 0.00% 0.000us 0.000us 37.369ms 31.14% 37.369ms 3.737ms 10
AddmmBackward0 -0.19% -225.000us 0.90% 1.071ms 107.100us 0.000us 0.00% 36.673ms 3.667ms 10
volta_dgemm_128x64_nt 0.00% 0.000us 0.00% 0.000us 0.000us 33.590ms 27.99% 33.590ms 5.598ms 6
autograd::engine::evaluate_function: BmmBackward0 0.03% 30.000us 0.66% 791.000us 131.833us 0.000us 0.00% 32.940ms 5.490ms 6
BmmBackward0 -0.10% -122.000us 0.61% 734.000us 122.333us 0.000us 0.00% 32.758ms 5.460ms 6
autograd::engine::evaluate_function: LinalgSolveExBa... 0.02% 24.000us 63.43% 75.798ms 18.950ms 0.000us 0.00% 15.985ms 3.996ms 4
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 119.494ms
Self CUDA time total: 119.989ms
Hi @samdow, apologies for re-opening this issue, but could batching rules also be added for the torch.linalg.lu* functions? It seems that when torch.linalg.slogdet is called, it calls torch.linalg.lu for the decomposition (in order to compute the determinant in the log-domain), which doesn't seem to have a batching rule. I've posted the warning below; it highlights aten::linalg_lu_solve as not having a batching rule.
~/main.py:201: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::linalg_lu_solve. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /tmp/pip-req-build-vire9c5a/functorch/csrc/BatchedFallback.cpp:83.)
sgns, logabss = torch.slogdet(matrices * torch.exp(log_envs))
~/anaconda3/envs/pytorch_nightly_env/lib/python3.10/site-packages/torch/autograd/__init__.py:294: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::linalg_lu_solve. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /tmp/pip-req-build-vire9c5a/functorch/csrc/BatchedFallback.cpp:83.)
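For context on why the LU ops show up: slogdet can be obtained from an LU factorisation, with the log-determinant read off the diagonal of U and the sign from the diagonal signs and the pivot parity. A rough numerical check of that relationship, using torch.linalg.lu_factor purely for illustration (not the exact internal kernels):
import torch

A = torch.randn(4, 4, dtype=torch.float64)
sgn_ref, logabs_ref = torch.linalg.slogdet(A)

# slogdet reconstructed from an LU factorisation (illustrative only)
LU, pivots = torch.linalg.lu_factor(A)
diag = LU.diagonal(dim1=-2, dim2=-1)
logabs = diag.abs().log().sum(-1)
num_swaps = (pivots != torch.arange(1, A.shape[-1] + 1)).sum(-1)  # pivots are 1-indexed
sgn = diag.sign().prod(-1) * (1.0 - 2.0 * (num_swaps % 2).to(A.dtype))

assert torch.allclose(sgn, sgn_ref) and torch.allclose(logabs, logabs_ref)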
I added linalg_lu_solve yesterday, could you reinstall the latest pytorch nightly and try again?
Hi @zou3519, sorry for the late response. I've installed the latest pytorch nightly and the UserWarning isn't there anymore. Thank you!
EDIT: removed issue with functorch install. Fresh install works fine!
Hi @zou3519, I've just noticed that if torch.slogdet (instead of torch.linalg.slogdet) is used, it defaults to a for-loop and stack, but torch.linalg.slogdet works as expected. I assume torch.slogdet is going to be removed in a future update (as everything is moving to the linalg library), but I thought I'd mention it here in case this problem comes up in other situations with other functions.
UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::slogdet. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /opt/conda/conda-bld/pytorch_1664781140419/work/aten/src/ATen/functorch/BatchedFallback.cpp:82.)
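Until the alias gets its own batching rule, switching the call to torch.linalg.slogdet avoids the fallback; a minimal way to see which spelling triggers the warning under vmap (toy matrices, for illustration):
import torch
from functorch import vmap

mats = torch.randn(8, 3, 3, dtype=torch.float64)

# triggers the aten::slogdet fallback warning under vmap (at the time of writing)
sgn_a, logabs_a = vmap(torch.slogdet)(mats)

# uses the batching rule, no warning
sgn_b, logabs_b = vmap(torch.linalg.slogdet)(mats)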
I think I know what is going on here, will fix soon. EDIT: fix over at https://github.com/pytorch/pytorch/pull/86815 . Thanks as always for reporting bugs, @AlphaBetaGamma96.
It doesn't look like we've actually deprecated torch.slogdet in favor of torch.linalg.slogdet, is that right @lezcano ? (https://pytorch.org/docs/1.13/generated/torch.slogdet.html?highlight=torch+slogdet#torch.slogdet). In that case we do want vmap support for both operators since users aren't being directed to use torch.linalg.slogdet over torch.slogdet.
We haven't deprecated it. It's left there as an alias:
// Alias
std::tuple<Tensor, Tensor> slogdet(const Tensor& A) {
return at::linalg_slogdet(A);
}
std::tuple<Tensor&, Tensor&> slogdet_out(const Tensor& A, Tensor& sign, Tensor& logabsdet) {
return at::linalg_slogdet_out(sign, logabsdet, A);
}
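In Python terms the alias simply forwards to the linalg function, so both spellings return the same (sign, logabsdet) pair, e.g.:
import torch

A = torch.randn(5, 5, dtype=torch.float64)
assert torch.equal(torch.slogdet(A)[0], torch.linalg.slogdet(A)[0])
assert torch.equal(torch.slogdet(A)[1], torch.linalg.slogdet(A)[1])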