
Using BatchNorm2D raises an in-place operation error

MaxH1996 opened this issue · 7 comments

I am working on a project which requires me to calculate the trace of the Hessian of standard ResNet architectures. To this end I am using the Hutchinson method, which requires me to form the Hessian-vector product. I am currently using ResNet18 as implemented in `torchvision`. This entails `BatchNorm2D` operations with `track_running_stats=True`. If I set `track_running_stats=False` I can execute the following code without any problems:

```python
import torch
from functorch import make_functional_with_buffers, grad, vjp

criterion = torch.nn.CrossEntropyLoss()

def rademacher(shape, dtype=torch.float32, device='cuda'):
    # Random vector with entries drawn uniformly from {-1, +1}.
    rand = (torch.rand(shape) < 0.5) * 2 - 1
    return rand.to(dtype).to(device)

def loss(params, batch, fn, buffers):
    x, y = batch
    out = fn(params, buffers, x)
    return criterion(out, y)

def hvp(params, batch, v, fn, buffers):
    # Hessian-vector product via reverse-over-reverse: vjp of the gradient.
    loss_fn = lambda p: loss(p, batch, fn, buffers)
    _, vjp_fn = vjp(grad(loss_fn), params)
    return vjp_fn(v)[0]

def hutchinson(net, x, y, iterations, device='cuda'):
    fn, params, buffers = make_functional_with_buffers(net)
    params = [p.detach() for p in params]

    trace = 0
    V = iterations
    for _ in range(V):
        v = [rademacher(p.shape, device=device) for p in params]
        Hv = hvp(params, (x, y), v, fn, buffers)

        # Accumulate v^T H v / V over the parameter blocks.
        for v_i, Hv_i in zip(v, Hv):
            vHv = torch.einsum("i,i->", v_i.flatten(), Hv_i.flatten())
            trace += vHv / V
    return trace
```

where `net` is my ResNet18 and `x` and `y` are my images and labels respectively. However, if I set `track_running_stats=True` I get the following error:

```
RuntimeError: During a grad (vjp, jvp, grad, etc) transform, the function provided attempted to call in-place operation (aten::add_.Tensor) that would mutate a captured Tensor. This is not supported; please rewrite the function being transformed to explicitly accept the mutated Tensor(s) as inputs.
```

I have encountered the same problem when computing the NTK using the example given in the functorch documentation. Is there a quick workaround to this problem?

Thanks in advance.

MaxH1996 · Jul 26 '22

I came across this page, which explains the issue and gives some solutions to it! https://pytorch.org/functorch/stable/batch_norm.html
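
One of the fixes suggested on that page is a helper that rewrites every batch norm module so it stops tracking running stats. A minimal sketch, assuming functorch's experimental `replace_all_batch_norm_modules_` helper, and only applicable if you can do without running statistics:

```python
import torchvision
from functorch.experimental import replace_all_batch_norm_modules_

net = torchvision.models.resnet18()
# In-place module surgery: every BatchNorm stops using running stats,
# so grad transforms no longer hit the in-place buffer update.
replace_all_batch_norm_modules_(net)
```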

AlphaBetaGamma96 · Jul 26 '22

Hey @AlphaBetaGamma96, thanks for your response. Unfortunately I do need running statistics, as I am using stochastic weight averaging. When using the PyTorch Lightning module for SWA, for example, I get an error at evaluation. If I understand correctly, the page you linked suggests opening an issue if I have a use case which requires running stats.

MaxH1996 · Jul 26 '22

Thanks @AlphaBetaGamma96 for linking to that, that's the right link for this. I wanted the error message to be more insightful and link to that page, so I'll investigate why it's not coming up.

@MaxH1996 thanks for this issue and for clarifying that you do need running stats. I'm not very familiar with the Hutchinson method, so would you mind helping me clarify: do you need the running stats to update while you're computing the hvp? Or is the intention that you would train the model and then use the stats computed during training to compute the hvp?

samdow · Jul 26 '22

Hi @MaxH1996, that's unfortunate. Potentially you could copy the source for BatchNorm2D and replace all the in-place operations with their out-of-place equivalents? (If that's at all possible.)

@samdow The Hutchinson estimator is a way to stochastically estimate the trace of a matrix. There's a decent explanation of it in the BackPACK docs that I came across here. (Although it might be better for @MaxH1996 to explain their use case.)

[image: excerpt from the BackPACK docs explaining the Hutchinson trace estimator]
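
For reference, the estimator presumably shown in that excerpt is the standard one (my own restatement, not the BackPACK wording): for a square matrix $A$ and i.i.d. Rademacher vectors $v_i$,

$$\operatorname{Tr}(A) = \mathbb{E}\left[v^\top A v\right] \approx \frac{1}{V} \sum_{i=1}^{V} v_i^\top A v_i .$$

With $A = \nabla^2 \mathcal{L}$, each term needs only a Hessian-vector product $\nabla^2 \mathcal{L}\, v_i$, which is exactly what the `hvp` function above computes.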

AlphaBetaGamma96 · Jul 26 '22

@samdow Essentially what @AlphaBetaGamma96 has posted about the Hutchinson method is what I am using. During training I compute the trace of the Hessian every n-th step and then do some further calculations. At the same time I am also using stochastic weight averaging (SWA), which (at least in the PyTorch Lightning implementation) requires running stats. So really this is where my issue comes from.

@AlphaBetaGamma96 I am not all too familiar with the source code of BatchNorm2D. I could do that, but I was hoping for something I could change on the functorch side of the implementation. Still, this could work if nothing else does.

MaxH1996 · Jul 26 '22

@MaxH1996 Neither am I, although I did have a quick look at it (source here) and there are indeed some in-place operations (I'm not sure if that's needed by design), but perhaps an out-of-place equivalent might be easy to make?

AlphaBetaGamma96 · Jul 26 '22

Hi @MaxH1996. Not sure if you still need a solution for it, but the only in-place operation that you need to change for BatchNorm2D (in `torch/nn/modules/batchnorm.py`) is `self.num_batches_tracked.add_(1)`. Simply replacing it with `self.num_batches_tracked = self.num_batches_tracked + 1` fixed the issue in ResNets for me.
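
If you'd rather not edit the installed PyTorch source, the same change can be applied via a small subclass. A hypothetical sketch (`PatchedBatchNorm2d` and `patch_batchnorms_` are made-up names, and the forward below is a simplified version of the stock `_BatchNorm.forward`):

```python
import torch
import torch.nn.functional as F

class PatchedBatchNorm2d(torch.nn.BatchNorm2d):
    """BatchNorm2d whose batch counter is updated out-of-place."""

    def forward(self, input):
        self._check_input_dim(input)
        exponential_average_factor = 0.0 if self.momentum is None else self.momentum
        if self.training and self.track_running_stats:
            # Out-of-place replacement for `self.num_batches_tracked.add_(1)`;
            # assigning a Tensor to a registered buffer name updates the buffer.
            self.num_batches_tracked = self.num_batches_tracked + 1
            if self.momentum is None:  # use a cumulative moving average
                exponential_average_factor = 1.0 / float(self.num_batches_tracked)
        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps,
        )

def patch_batchnorms_(module):
    """Recursively swap every BatchNorm2d in `module` for the patched version."""
    for name, child in module.named_children():
        if isinstance(child, torch.nn.BatchNorm2d):
            patched = PatchedBatchNorm2d(
                child.num_features, child.eps, child.momentum,
                child.affine, child.track_running_stats,
            )
            patched.load_state_dict(child.state_dict())
            setattr(module, name, patched)
        else:
            patch_batchnorms_(child)
```

Calling `patch_batchnorms_(net)` before `make_functional_with_buffers(net)` should then give the same behavior as editing the source directly.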

HamedHemati · Jun 07 '23