memory_efficient_fusion leads to RuntimeError for higher-order gradients calculation. RuntimeError: You are attempting to call Tensor.requires_grad_()
Hi All,
I've tried improving the speed of my code by using memory_efficient_fusion; however, it leads to a Tensor.requires_grad_() error and I have no idea why. The error is as follows,
RuntimeError: You are attempting to call Tensor.requires_grad_() (or perhaps using torch.autograd.functional.* APIs) inside of a function being transformed by a functorch transform. This is unsupported, please attempt to use the functorch transforms (e.g. grad, vjp, jacrev, jacfwd, hessian) or call requires_grad_() outside of a function being transformed instead.
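For reference, the pattern that message forbids looks something like the sketch below (purely illustrative, not my actual code; my code never calls requires_grad_() directly, so I assume something inside memory_efficient_fusion does):

import torch
from functorch import grad

def f_bad(x):
    # calling requires_grad_() inside a functorch transform raises the RuntimeError above
    x.requires_grad_()
    return (x ** 2).sum()

def f_good(x):
    return (x ** 2).sum()

x0 = torch.randn(3)
# grad(f_bad)(x0)        # RuntimeError: You are attempting to call Tensor.requires_grad_() ...
print(grad(f_good)(x0))  # fine: returns 2 * x0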
I've attached a 'minimal' reproducible example of this behaviour below. I've tried a few different things but nothing seems to have worked. I did see in #840 that memory_efficient_fusion is called within a context manager; however, when using that I get the same error.
Thanks in advance!
EDIT: When I tried running it, it tried to use the networkx package, but that wasn't installed by default, so I had to install it manually (which wasn't a problem). I'm just not sure whether installing from source should also install those packages!
import torch
from torch import nn
import functorch
from functorch import make_functional, vmap, jacrev, grad
from functorch.compile import memory_efficient_fusion
import time
_ = torch.manual_seed(1234)
#version info
print("PyTorch version: ", torch.__version__)
print("CUDA version: ", torch.version.cuda)
print("FuncTorch version: ", functorch.__version__)
#=============================================#
#time with torch synchronization
def sync_time() -> float:
    torch.cuda.synchronize()
    return time.perf_counter()
class model(nn.Module):
    def __init__(self, num_inputs, num_hidden):
        super(model, self).__init__()
        self.num_inputs = num_inputs
        self.func = nn.Tanh()
        self.fc1 = nn.Linear(2, num_hidden)
        self.fc2 = nn.Linear(num_hidden, num_inputs)

    def forward(self, x):
        """
        Takes x in [B,A,1] and maps it to sign/logabsdet value in Tuple([B,], [B,])
        """
        idx = len(x.shape)  # creates args for repeat whether vmap is used or not
        rep = [1 for _ in range(idx)]
        rep[-2] = self.num_inputs
        g = x.mean(dim=(idx - 2), keepdim=True).repeat(*rep)
        f = torch.cat((x, g), dim=-1)
        h = self.func(self.fc1(f))
        mat = self.fc2(h)
        sgn, logabs = torch.linalg.slogdet(mat)
        return sgn, logabs
#=============================================#
B=4096 #batch
N=2 #input nodes
H=64 #number of hidden nodes
device = torch.device('cuda')
x = torch.randn(B, N, 1, device=device) #input data
net = model(N, H) #our model
net=net.to(device)
fnet, params = make_functional(net)
def calc_logabs(params, x):
    _, logabs = fnet(params, x)
    return logabs

def calc_dlogabs_dx(params, x):
    dlogabs_dx = jacrev(func=calc_logabs, argnums=1)(params, x)
    return dlogabs_dx, dlogabs_dx  # return aux

def local_kinetic_from_log_vmap(params, x):
    d2logabs_dx2, dlogabs_dx = jacrev(func=calc_dlogabs_dx, argnums=1, has_aux=True)(params, x)
    _local_kinetic = -0.5*(d2logabs_dx2.diagonal(0, -4, -2).sum() + dlogabs_dx.pow(2).sum())
    return _local_kinetic
#memory efficient fusion here
#with torch.jit.fuser("fuser2"): is this needed (from functorch/issues/840)
ps_elocal = grad(local_kinetic_from_log_vmap, argnums=0)
ps_elocal_fusion = memory_efficient_fusion(grad(local_kinetic_from_log_vmap, argnums=0))
#ps_elocal_fusion(params, x) #no vmap attempt (throws size mis-match error)
t1=sync_time()
vmap(ps_elocal, in_dims=(None, 0))(params, x) #works fine
t2=sync_time()
vmap(ps_elocal_fusion, in_dims=(None, 0))(params, x) #error (crashes on this line)
t3=sync_time()
print("Laplacian (standard): %4.2e (s)",t2-t1)
print("Laplacian (fusion): %4.2e (s)",t3-t2)
So it seems the solution is to place the vmap call within the memory_efficient_fusion call, like so,
ps_elocal_fusion = memory_efficient_fusion(vmap(grad(local_kinetic_from_log_vmap, argnums=0), in_dims=(None, 0)))
then just call,
ps_elocal_fusion(params, x) #works now.
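As a sanity check (my addition, just to make sure the fused call isn't silently returning something different), one could compare the two versions' outputs, e.g.

# both return a tuple of per-parameter gradients with a leading batch dim of B
for g_ref, g_fused in zip(vmap(ps_elocal, in_dims=(None, 0))(params, x),
                          ps_elocal_fusion(params, x)):
    assert torch.allclose(g_ref, g_fused, atol=1e-5)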
However, it's about an order of magnitude slower than the non-memory_efficient_fusion version.
ps_elocal = vmap(grad(local_kinetic_from_log_vmap, argnums=0), in_dims=(None, 0)) #0.454 (s)
ps_elocal_fusion = memory_efficient_fusion(vmap(grad(local_kinetic_from_log_vmap, argnums=0), in_dims=(None, 0))) #5.804 (s)
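(Those timings are from single calls, so the fusion number presumably includes the one-off tracing/compilation cost. A small helper with warm-up iterations, sketched below with made-up warmup/repeats arguments, would give a fairer per-call figure; I haven't re-run the comparison this way yet.)

def timeit(fn, *args, warmup=3, repeats=10):
    # warm-up calls so any one-off compilation in memory_efficient_fusion isn't timed
    for _ in range(warmup):
        fn(*args)
    t0 = sync_time()
    for _ in range(repeats):
        fn(*args)
    return (sync_time() - t0) / repeats

print("Laplacian (standard): %4.2e (s)" % timeit(ps_elocal, params, x))
print("Laplacian (fusion):   %4.2e (s)" % timeit(ps_elocal_fusion, params, x))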
EDIT: versions for reference:
PyTorch version: 1.13.0.dev20220820
CUDA version: 11.6
FuncTorch version: 0.3.0a0+86a9049
cc @Chillee @anijain2305
Any thoughts? In particular, re: why memory_efficient_fusion made the final case slower.
I thought I'd also mention that memory_efficient_fusion fails if a scalar is included. For example, using this function (which includes the scalar -0.5 factor),
def local_kinetic_from_log_vmap(params, x):
    d2logabs_dx2, dlogabs_dx = jacrev(func=calc_dlogabs_dx, argnums=1, has_aux=True)(params, x)
    _local_kinetic = -0.5*(d2logabs_dx2.diagonal(0, -4, -2).sum() + dlogabs_dx.pow(2).sum())
    return _local_kinetic
returns the following error,
RuntimeError: aten::_to_copy() Expected a value of type 'Tensor' for argument 'self' but instead found type 'float'.
Position: 0
Value: -0.5
Declaration: aten::_to_copy(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, int? memory_format=None) -> Tensor
Cast error details: Unable to cast -0.5 to Tensor
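A possible workaround for the scalar case (just a sketch on my side, not verified against the nightly above) is to build the -0.5 factor as a tensor up front so nothing needs to be cast from a Python float:

def local_kinetic_from_log_vmap(params, x):
    d2logabs_dx2, dlogabs_dx = jacrev(func=calc_dlogabs_dx, argnums=1, has_aux=True)(params, x)
    # express the scalar as a tensor on the input's device/dtype rather than a bare float,
    # which is what aten::_to_copy refuses to cast here
    half = torch.tensor(-0.5, dtype=x.dtype, device=x.device)
    _local_kinetic = half * (d2logabs_dx2.diagonal(0, -4, -2).sum() + dlogabs_dx.pow(2).sum())
    return _local_kinetic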