fast-autoaugment
Pyramidnet Issue
Hi,
I am currently trying to use PyramidNet + ShakeDrop. However, I am getting the following error:
RuntimeError: Output 0 of ShakeDropFunctionBackward is a view and is being modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can remove this warning by cloning the output of the custom Function.
If I try to fix the error by changing some lines, the memory usage seems to increase a lot, so I was wondering whether you have also encountered this error.
Thank you!
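For reference, here is a minimal sketch of how this class of error can be reproduced on a recent PyTorch version. This is not code from the repo; the class below is just a hypothetical stand-in for ShakeDropFunction that returns its input as-is, which is one of the triggers named in the error message:

import torch

class ShakeDropLike(torch.autograd.Function):
    """Hypothetical stand-in for ShakeDropFunction: returns its input as-is,
    so autograd marks the output as a view created inside a custom Function."""

    @staticmethod
    def forward(ctx, x):
        return x  # input returned unchanged -> output aliases the input

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

x = torch.randn(8, requires_grad=True)
y = x * 2                      # non-leaf tensor with autograd history
out = ShakeDropLike.apply(y)   # output is treated as a view of y
out += 1                       # in-place write on that view raises
                               # "Output 0 of ShakeDropLikeBackward is a view
                               #  and is being modified inplace ..."

As the end of the error message suggests, cloning the Function's output (e.g. out = ShakeDropLike.apply(y).clone()) before the in-place operation makes the error go away.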
I'm experiencing the same issue.
I figured out a way to prevent this. If you look at pyramidnet.py, lines 106-120:
if residual_channel != shortcut_channel:
    padding = torch.autograd.Variable(
        torch.cuda.FloatTensor(batch_size, residual_channel - shortcut_channel, featuremap_size[0],
                               featuremap_size[1]).fill_(0))
    out += torch.cat((shortcut, padding), 1)
It creates variables on the fly, which autograd will hopefully clean up on the backward pass. There are a few issues with this approach:
- If you don't have a backward pass and just want to forward some inputs, the extra memory allocated here never gets cleaned up.
- If you have torch.nn.DataParallel wrapped around it, say you are running on 4 GPUs, then (I'm guessing) autograd will clean up only the last variable and the other 3 remain in memory. That's why you get a memory leak (actual RAM, not GPU memory).

I replaced these lines with the following, and that seems to have solved the problem:
if residual_channel != shortcut_channel:
    out[:, :shortcut.size(1)] = out[:, :shortcut.size(1)] + shortcut
else:
    out = out + shortcut
To be honest, this sometimes fails too, but not because of the memory leak. I have not figured out why this approach fails, so it would be helpful if anyone knows why.
Update: I figured out the problem. Here's a version that, to the best of my knowledge, does not throw any errors and also does not leak memory:
if residual_channel != shortcut_channel:
    # clone() gives `out` its own storage, so the slice assignment below is no
    # longer an in-place write on the custom Function's output view
    out = out.clone()
    out[:, :shortcut.size(1)] = out[:, :shortcut.size(1)] + shortcut
else:
    out = out + shortcut
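If you prefer to avoid in-place writes and the hand-made padding tensor altogether, an out-of-place variant is to zero-pad shortcut along the channel dimension with F.pad and add it normally. This is only a sketch using the same variable names as the snippets above, not code from the repo:

import torch.nn.functional as F

if residual_channel != shortcut_channel:
    # the pad tuple is (W_left, W_right, H_top, H_bottom, C_front, C_back) for an
    # NCHW tensor, so this appends zero channels to `shortcut` up to residual_channel
    out = out + F.pad(shortcut, (0, 0, 0, 0, 0, residual_channel - shortcut_channel))
else:
    out = out + shortcut

Since nothing is modified in place, autograd never has to rebase a view onto the custom backward, and no extra buffer has to be allocated by hand.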