memcnn
ReLU before Invertible layers
- MemCNN version: 1.5.1
- PyTorch version: 1.10.1
- Python version: 3.8.12
- Operating System: Windows 10
Description
Hi, thank you for making this super useful library publicly available. I tried to use it, but I encountered a weird problem. It would be great if I could get some help here.
What I Did
Here's the minimal code to reproduce the error:
import torch
import torch.nn as nn
import memcnn


class Example(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain (non-invertible) stem: 1 -> 4 channels, followed by a ReLU
        self.conv1 = nn.Sequential(
            nn.Conv3d(in_channels=1, out_channels=4, kernel_size=1),
            nn.ReLU()
        )
        # Additive coupling over the 4 channels (2 + 2 split)
        iconv1 = memcnn.AdditiveCoupling(
            Fm=nn.Sequential(nn.Conv3d(4 // 2, 4 // 2, kernel_size=3, padding=1),
                             nn.ReLU()),
            Gm=nn.Sequential(nn.Conv3d(4 // 2, 4 // 2, kernel_size=3, padding=1),
                             nn.ReLU()),
        )
        # keep_input=False lets MemCNN free the wrapper's input during the forward pass
        self.iconv1_wrapper = memcnn.InvertibleModuleWrapper(
            fn=iconv1, keep_input=False, keep_input_inverse=False
        )

    def forward(self, x):
        out = self.conv1(x)
        out = self.iconv1_wrapper(out)
        return out


model = Example()
X = torch.randn(1, 1, 16, 16, 16)
Y = model(X)
Y.sum().backward()
Here's the error message and stack trace:
Traceback (most recent call last):
  File "test.py", line 33, in <module>
    Y.sum().backward()
  File "C:\Users\50139\.julia\conda\3\envs\kd\lib\site-packages\torch\_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "C:\Users\50139\.julia\conda\3\envs\kd\lib\site-packages\torch\autograd\__init__.py", line 154, in backward
    Variable._execution_engine.run_backward(
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1, 4, 16, 16, 16]], which is output 0 of ReluBackward0, is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
However, if I remove the ReLU in self.conv1, or construct self.iconv1_wrapper with keep_input=True, the error goes away.
I think one possible cause is that memcnn.InvertibleModuleWrapper modifies the output of ReLU() somewhere (it is the wrapper's input, and keep_input=False), which is expected. But then why doesn't modifying the output of the conv layer trigger the same error?
Thank you so much for your help!
Hi, thank you for your interest in MemCNN. You have identified an interesting behavior, which wasn't present in PyTorch 1.7.0 (the last version tested with MemCNN). Apparently, it only happens when the InvertibleModuleWrapper comes right after operations like ReLU (which, strangely enough, is not an in-place operation by default). If, for example, I put another convolution layer after the ReLU, it seems to work again.
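To illustrate why the ReLU specifically is affected, here is a minimal sketch outside of MemCNN (my own toy example, not code from this library): ReLU saves its output for the backward pass, as the ReluBackward0 in your trace suggests, whereas a convolution saves its input and weights. So an in-place change to a ReLU output trips autograd's version check, while the same change to a convolution output does not:

import torch
import torch.nn as nn

x = torch.randn(1, 4, 8, requires_grad=True)

# ReLU saves its *output* for backward (ReluBackward0), so modifying that
# output in place bumps its version counter and backward() raises the same
# RuntimeError as in the report above.
y = nn.functional.relu(x)
y.zero_()  # stand-in for whatever frees/overwrites the tensor later on
try:
    y.sum().backward()
except RuntimeError as e:
    print("ReLU case fails:", e)

# A convolution saves its *input* and weights instead of its output, so the
# same in-place change to its output leaves the saved tensors untouched.
conv = nn.Conv1d(4, 4, kernel_size=1)
z = conv(x)
z.zero_()
z.sum().backward()  # works fine
print("Conv case succeeds")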
For now, this behavior can be circumvented by cloning the output in the forward pass, directly after the ReLU and before the invertible wrapper, like this:
def forward(self, x):
    # clone() gives the wrapper its own copy to free, so the tensor saved
    # by ReLU's backward pass stays intact
    out = self.conv1(x).clone()
    out = self.iconv1_wrapper(out)
    return out
This should at least make your code work. Keep in mind that the clone operation adds memory overhead, so clones like this should only be used where necessary.
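As you already found, setting keep_input=True on the wrapper is another way around it: the wrapper then no longer frees the tensor that ReLU's backward pass still needs, at the cost of keeping that activation in memory. A sketch of that drop-in change inside __init__ of your reproduction above (keep_input_inverse should only matter if you also call inverse(), if I am not mistaken):

# Alternative workaround: keep the wrapper's input alive instead of cloning.
self.iconv1_wrapper = memcnn.InvertibleModuleWrapper(
    fn=iconv1,
    keep_input=True,           # input is kept, so the tensor ReLU saved is never freed
    keep_input_inverse=False,  # unchanged; only relevant for the inverse pass
)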