functorch
Having BatchNorm2D raises in-place operation error
I am working on a project that requires me to calculate the trace of the Hessian of standard ResNet architectures. To this end I am using the Hutchinson method, which requires me to form the Hessian-vector product. I am currently using ResNet18 as implemented in torchvision, which contains `BatchNorm2d` layers with `track_running_stats=True`. If I set `track_running_stats=False`, I can execute the following code without any problems:
```python
import torch
from functorch import make_functional_with_buffers, grad, vjp

criterion = torch.nn.CrossEntropyLoss()

def rademacher(shape, dtype=torch.float32, device='cuda'):
    # Entries are +1 or -1 with equal probability.
    rand = (torch.rand(shape) < 0.5) * 2 - 1
    return rand.to(dtype=dtype, device=device)

def loss(params, batch, fn, buffers):
    x, y = batch
    out = fn(params, buffers, x)
    return criterion(out, y)

def hvp(params, batch, v, fn, buffers):
    # Hessian-vector product as the VJP of the gradient (reverse-over-reverse);
    # this equals H v because the Hessian is symmetric.
    loss_fn = lambda p: loss(p, batch, fn, buffers)
    _, vjp_fn = vjp(grad(loss_fn), params)
    return vjp_fn(v)[0]

def hutchinson(net, x, y, iterations, device='cuda'):
    fn, params, buffers = make_functional_with_buffers(net)
    params = [p.data for p in params]
    trace = 0
    for _ in range(iterations):
        v = [rademacher(p.shape, device=device) for p in params]
        Hv = hvp(params, (x, y), v, fn, buffers)
        # Accumulate v^T H v over all parameter tensors.
        for v_i, Hv_i in zip(v, Hv):
            trace += torch.dot(v_i.flatten(), Hv_i.flatten()) / iterations
    return trace
```
where `net` is my ResNet18 and `x` and `y` are my images and labels, respectively. However, if I set `track_running_stats=True`, I get the following error:
RuntimeError: During a grad (vjp, jvp, grad, etc) transform, the function provided attempted to call in-place operation (aten::add_.Tensor) that would mutate a captured Tensor. This is not supported; please rewrite the function being transformed to explicitly accept the mutated Tensor(s) as inputs.
I have encountered the same problem when computing the NTK using the example given in the functorch documentation. Is there a quick workaround for this problem?
Thanks in advance.
I came across this page, which explains the issue and gives some solutions: https://pytorch.org/functorch/stable/batch_norm.html
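One of the workarounds on that page is to drop the running statistics entirely. Below is a minimal sketch of that option. It assumes PyTorch 2.x, where functorch has been folded into `torch.func`, which provides `replace_all_batch_norm_modules_` and `functional_call`. The small `nn.Sequential` is only an illustrative stand-in for a ResNet. Note that this option discards the running stats, so it does not help when they are actually needed.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad, replace_all_batch_norm_modules_

# Illustrative stand-in for a ResNet: conv -> batch norm -> linear head.
net = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.BatchNorm2d(4),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(4 * 16 * 16, 10),
)

# In place: every BatchNorm layer loses its running-stat buffers and gets
# track_running_stats=False, so the forward pass no longer mutates buffers.
replace_all_batch_norm_modules_(net)

params = {k: v.detach() for k, v in net.named_parameters()}
criterion = nn.CrossEntropyLoss()

def loss_fn(p, x, y):
    return criterion(functional_call(net, p, (x,)), y)

x = torch.randn(8, 3, 16, 16)
y = torch.randint(0, 10, (8,))
grads = grad(loss_fn)(params, x, y)  # dict with the same keys as params
```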
Hey @AlphaBetaGamma96, thanks for your response. Unfortunately, I do need running statistics because I am using stochastic weight averaging (SWA). When using the PyTorch Lightning module for SWA, for example, I get an error at evaluation time. If I understand correctly, the page you linked suggests opening an issue if I have a use case that requires running stats.
Thanks @AlphaBetaGamma96 for linking to that; it's the right page for this. I intended the error message to be more informative and to link to that page, so I'll investigate why the link isn't showing up.
@MaxH1996 thanks for this issue and for clarifying that you do need running stats. I'm not very familiar with the Hutchinson method, so could you help me clarify: do you need the running stats to update while you're computing the hvp? Or is the intention that you train the model and then use the stats computed during training to compute the hvp?
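If the stats from training only need to be read, not updated, while the HVP is computed, one option is to put the network in eval mode. `BatchNorm2d` then normalizes with the stored statistics and leaves its buffers alone. Below is a minimal sketch assuming PyTorch 2.x `torch.func`. `hutchinson_trace_eval` is a hypothetical helper. Be aware that it estimates the trace of the Hessian of the eval-mode loss, which is a different function from the train-mode loss that normalizes with per-batch statistics.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad, vjp

def hutchinson_trace_eval(net, x, y, iterations=10):
    """Hypothetical helper: Hutchinson estimate of tr(H) for the eval-mode loss."""
    was_training = net.training
    net.eval()  # BatchNorm reads running stats and does not update them
    params = {k: p.detach() for k, p in net.named_parameters()}
    buffers = {k: b.detach() for k, b in net.named_buffers()}
    criterion = nn.CrossEntropyLoss()

    def loss_fn(p):
        out = functional_call(net, (p, buffers), (x,))
        return criterion(out, y)

    # Linearize grad(loss_fn) once; hvp_fn(v) returns a 1-tuple (H v,).
    _, hvp_fn = vjp(grad(loss_fn), params)
    trace = 0.0
    for _ in range(iterations):
        # Rademacher probe: each entry is +1 or -1 with equal probability.
        v = {k: (torch.rand_like(p) < 0.5).to(p.dtype) * 2 - 1
             for k, p in params.items()}
        (hv,) = hvp_fn(v)
        trace = trace + sum((v[k] * hv[k]).sum() for k in params) / iterations
    net.train(was_training)  # restore the original mode
    return trace
```

Because `net.eval()` is restored afterwards, this could be called every n-th training step without changing the running stats that SWA later relies on.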
Hi @MaxH1996, that's unfortunate. Could you copy the source for `BatchNorm2d` and replace all in-place operations with their out-of-place equivalents (if that's at all possible)?
@samdow The Hutchinson estimator is a way to stochastically estimate the trace of a matrix. There's a decent explanation of it in the BackPACK docs that I came across here. (Although it might be better for @MaxH1996 to explain their use case.)
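The estimator rests on one identity. For a random vector v with independent ±1 (Rademacher) entries, E[vᵀAv] = tr(A) for any square matrix A. Averaging vᵀAv over samples therefore only needs matrix-vector products, and for a Hessian these are exactly the HVPs in the code at the top of this issue. Below is a dependency-free sketch in plain Python. `hutchinson_trace` and `dense_matvec` are hypothetical names, and in practice `matvec` would be the HVP rather than an explicit matrix.

```python
import random

def hutchinson_trace(matvec, n, num_samples, seed=0):
    """Estimate tr(A) as the mean of v^T A v over Rademacher vectors v.

    `matvec(v)` must return A @ v as a list; A itself is never needed.
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        v = [rng.choice((-1.0, 1.0)) for _ in range(n)]
        av = matvec(v)
        total += sum(vi * avi for vi, avi in zip(v, av))  # v^T A v
    return total / num_samples

def dense_matvec(a):
    """Matrix-vector product for a matrix given as a list of rows."""
    return lambda v: [sum(aij * vj for aij, vj in zip(row, v)) for row in a]
```

For a diagonal matrix, every sample is exact because vᵀDv = Σ dᵢvᵢ² = Σ dᵢ. All of the variance comes from the off-diagonal entries.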

@samdow Essentially, what @AlphaBetaGamma96 posted about the Hutchinson method is what I am using. During training I compute the trace of the Hessian every n-th step and then do some further calculations. At the same time, I am also using stochastic weight averaging (SWA), which (at least in the PyTorch Lightning implementation) requires running stats. That is really where my issue comes from.
@AlphaBetaGamma96 I am not very familiar with the source code of `BatchNorm2d`. I could do that, but I was hoping for something I could change on the functorch side of the implementation. Still, this could work if nothing else does.
@MaxH1996 Neither am I, although I did have a quick look at it (source here) and there are indeed some in-place operations (I'm not sure whether they're needed by design). Perhaps an out-of-place equivalent would be easy to make?
Hi @MaxH1996. Not sure if you still need a solution, but the only in-place operation you need to change in `BatchNorm2d` (in `torch/nn/modules/batchnorm.py`) is `self.num_batches_tracked.add_(1)`. Simply replacing it with `self.num_batches_tracked = self.num_batches_tracked + 1` fixed the issue in ResNets for me.
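That change can be applied without editing the installed `batchnorm.py` by using a subclass that overrides `forward`. Below is a sketch. It assumes the `nn.BatchNorm2d` forward logic of recent PyTorch releases, which may differ in other versions, and it relies on the private `_check_input_dim`. `OutOfPlaceBatchNorm2d` and `use_out_of_place_batchnorm` are hypothetical names. Swapping `__class__` in place keeps every parameter, buffer, and device placement, so existing references to the parameters (e.g. in an optimizer) stay valid.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class OutOfPlaceBatchNorm2d(nn.BatchNorm2d):
    """Hypothetical BatchNorm2d that bumps num_batches_tracked out of place."""

    def forward(self, input):
        self._check_input_dim(input)
        factor = 0.0 if self.momentum is None else self.momentum
        if self.training and self.track_running_stats and self.num_batches_tracked is not None:
            # Out-of-place replacement for self.num_batches_tracked.add_(1);
            # Module.__setattr__ stores the new tensor in the buffer slot.
            self.num_batches_tracked = self.num_batches_tracked + 1
            if self.momentum is None:  # cumulative moving average
                factor = 1.0 / float(self.num_batches_tracked)
        # Use batch statistics in training, or when no running stats exist.
        bn_training = self.training or (self.running_mean is None and self.running_var is None)
        use_stats = not self.training or self.track_running_stats
        return F.batch_norm(
            input,
            self.running_mean if use_stats else None,
            self.running_var if use_stats else None,
            self.weight, self.bias, bn_training, factor, self.eps,
        )

def use_out_of_place_batchnorm(model):
    """Switch every plain nn.BatchNorm2d in `model` to the subclass, in place."""
    for m in model.modules():
        if type(m) is nn.BatchNorm2d:
            m.__class__ = OutOfPlaceBatchNorm2d
    return model
```

For a torchvision ResNet, this would mean calling `use_out_of_place_batchnorm(resnet18())` before `make_functional_with_buffers`. The comment above reports that this single change was enough, but that may depend on the PyTorch/functorch version.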