revlib
RevLib does not work with torch amp
Dear authors, when I use RevLib with torch amp, it raises the following error:
Traceback (most recent call last):
  File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/revlib/core.py", line 130, in backward
    mod_out = take_0th_tensor(new_mod.wrapped_module(y0, *ctx.args, **ctx.kwargs))
  File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 613, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 597, in _conv_forward
    return F.conv3d(
RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same
Dear authors, how can I solve this? Thanks in advance.
Unfortunately, I don't know. I don't use torch.amp myself, and I last touched RevLib a while ago. Please let me know if you have a minimal script to reproduce the error or manage to fix it. Just so you know, all pull requests are welcome.
Just so we're on the same page: torch.amp doesn't save memory; it only helps speed by downcasting matrix multiplications to fp16. Is this what you're after? If you want the memory improvements, please give RevLib's intermediate casting a try.
Dear author, I guess the error happens in the RevResNet backward pass, where the feature map dtype (float16) does not match the conv weight dtype (float32). It does not happen in the forward pass because the forward pass is wrapped in the torch.cuda.amp.autocast() context, which automatically casts the conv inputs and weights to half precision.
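If that diagnosis is right, one possible direction (a sketch under assumptions, untested against RevLib itself) is PyTorch's `torch.cuda.amp.custom_fwd` / `custom_bwd` decorators for custom autograd functions: `custom_fwd` records whether autocast was active during forward, and `custom_bwd` re-enters that state before backward runs. The class `RecomputeInBackward` below is hypothetical and not part of RevLib; it is a checkpoint-style stand-in that saves its input instead of reconstructing it through a reversible inverse. It only shows the decorator pattern around a "recompute in backward" step like the one at `revlib/core.py`, line 130.

```python
import torch
import torch.nn as nn
# Present since PyTorch 1.6; in 2.4+ these are deprecated in favour of
# torch.amp.custom_fwd(device_type="cuda") / torch.amp.custom_bwd(device_type="cuda").
from torch.cuda.amp import custom_bwd, custom_fwd


class RecomputeInBackward(torch.autograd.Function):
    """Hypothetical checkpoint-style helper (NOT RevLib's implementation).

    Forward runs `module` without recording a graph; backward re-runs it to
    compute gradients, i.e. the step that fails when it runs outside autocast.
    """

    @staticmethod
    @custom_fwd  # records whether CUDA autocast was enabled during forward
    def forward(ctx, x, module):
        ctx.module = module
        ctx.save_for_backward(x)  # assumption: store the input instead of inverting
        with torch.no_grad():
            return module(x)

    @staticmethod
    @custom_bwd  # re-enters the autocast state recorded by custom_fwd
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        x = x.detach().requires_grad_(True)
        with torch.enable_grad():
            # Under autocast, conv3d casts both input and weight to fp16 here,
            # so an fp16 feature map no longer meets an fp32 weight.
            out = ctx.module(x)
        # Accumulates into x.grad and directly into the module parameters' .grad.
        torch.autograd.backward(out, grad_out)
        return x.grad, None  # no gradient for the `module` argument


if torch.cuda.is_available():
    # Usage sketch: only the forward pass and loss sit inside autocast.
    conv = nn.Conv3d(4, 4, kernel_size=3, padding=1).cuda()
    x = torch.randn(2, 4, 8, 8, 8, device="cuda", requires_grad=True)
    with torch.autocast("cuda", dtype=torch.float16):
        y = RecomputeInBackward.apply(x, conv)  # fp16 output
        loss = y.float().mean()
    loss.backward()  # recomputation runs under fp16 autocast again
```

With the recomputation restored to the forward pass's autocast settings, `loss.backward()` can stay outside the `autocast` block, which is where the PyTorch AMP documentation recommends keeping it.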
Do you have a minimal example to reproduce the error?
I have an important deadline right now, but I'll reply as soon as I can. Thanks for your help.
When loss.backward() is wrapped in the torch.cuda.amp.autocast() context, this error does not occur.
Please share a minimal script to reproduce this error. I'll be able to take it from there.
The best next action would be a PR with a unit test for torch amp.
Alternatively, RevLib is open to contributions. You're welcome to submit a PR for the fix :)
Thanks a lot for RevLib. I hope to contribute myself once I have enough time.