
revlib cannot work with torch amp.

Open JAYatBUAA opened this issue 1 year ago • 8 comments

Dear authors, when using revlib with torch amp, it reports the following error:

```
Traceback (most recent call last):
  File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/revlib/core.py", line 130, in backward
    mod_out = take_0th_tensor(new_mod.wrapped_module(y0, *ctx.args, **ctx.kwargs))
  File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 613, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 597, in _conv_forward
    return F.conv3d(
RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same
```

JAYatBUAA avatar Jul 24 '23 07:07 JAYatBUAA

Dear authors, how can I solve it? Thanks in advance.

JAYatBUAA avatar Jul 24 '23 07:07 JAYatBUAA

Unfortunately, I don't know. I don't use torch.amp myself, and I last touched RevLib a while ago. Please let me know if you have a minimal script to reproduce the error or manage to fix it. Just so you know, all pull requests are welcome. Just so we're on the same page: torch.amp doesn't save memory but instead is only helpful for speed improvements by downcasting matrix multiplications to fp16. Is this what you're after? If you want the memory improvements, please give RevLib's intermediate casting a try.

ClashLuke avatar Jul 24 '23 10:07 ClashLuke

> Unfortunately, I don't know. I don't use torch.amp myself, and I last touched RevLib a while ago. Please let me know if you have a minimal script to reproduce the error or manage to fix it. Just so you know, all pull requests are welcome. Just so we're on the same page: torch.amp doesn't save memory but instead is only helpful for speed improvements by downcasting matrix multiplications to fp16. Is this what you're after? If you want the memory improvements, please give RevLib's intermediate casting a try.

Dear author, I guess the error happens in the RevResNet backward pass, where the feature map dtype (float16) does not match the conv weight dtype (float32). This does not happen in the forward pass, because the forward pass is wrapped in the torch.cuda.amp.autocast() context, where the conv computation is automatically cast to half precision.
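A minimal sketch of that suspected mismatch, without RevLib (the shapes and CPU device here are illustrative, not taken from the original model):

```python
import torch
import torch.nn.functional as F

# A half-precision feature map (as produced during an amp backward pass)
# hitting a conv3d weight that is still fp32. Shapes are made up.
x = torch.randn(1, 2, 4, 4, 4, dtype=torch.float16)  # fp16 activation
w = torch.randn(3, 2, 3, 3, 3, dtype=torch.float32)  # fp32 conv3d weight

try:
    F.conv3d(x, w)
except RuntimeError as err:
    # dtype mismatch between input and weight, as in the traceback above
    print(type(err).__name__, err)

# Casting the input to the weight's dtype makes the same call succeed:
out = F.conv3d(x.to(w.dtype), w)
print(out.dtype)  # torch.float32
```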

JAYatBUAA avatar Jul 24 '23 19:07 JAYatBUAA

Do you have a minimal example to reproduce the error?

ClashLuke avatar Jul 25 '23 06:07 ClashLuke

> Do you have a minimal example to reproduce the error?

I have an important deadline coming up, but I'll try to give you a reply as soon as possible. Thanks for your help.

JAYatBUAA avatar Jul 25 '23 07:07 JAYatBUAA

When loss.backward() is also wrapped in the torch.cuda.amp.autocast() context, this error is not reported.
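For reference, a minimal sketch of that workaround on a plain (non-RevLib) stand-in model, using CPU autocast with bfloat16 only so it runs without a GPU. Note that the PyTorch docs normally recommend calling backward() outside the autocast context; for reversible layers that recompute activations during backward, however, keeping it inside means the recomputation sees the same casting rules as the forward pass:

```python
import torch
import torch.nn as nn

# Stand-in model (the original issue uses a RevLib RevResNet with conv3d layers).
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
x = torch.randn(4, 8)

# Workaround reported above: keep loss.backward() inside the autocast context.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    loss = model(x).square().mean()
    loss.backward()  # kept inside the context, unlike the usual recipe

print(model[0].weight.grad is not None)  # gradients were computed
```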

JAYatBUAA avatar Jul 31 '23 13:07 JAYatBUAA

Please share a minimal script to reproduce this error. I'll be able to take it from there.

The best next action would be a PR with a unit test for torch amp.
RevLib is open to contributions, so you're welcome to submit a PR for the fix :)
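For anyone picking this up: one hypothetical shape for such a fix (this is not RevLib's actual code) is to record the autocast state during the forward pass and re-enter it before recomputing in backward, the way torch.utils.checkpoint does. Sketched with a generic custom autograd Function:

```python
import torch

class AutocastAwareRecompute(torch.autograd.Function):
    """Hypothetical sketch: recompute activations in backward under the
    same autocast state that was active during the forward pass."""

    @staticmethod
    def forward(ctx, x, module):
        ctx.module = module
        # Remember whether (CUDA) autocast was active during forward.
        ctx.had_autocast = torch.is_autocast_enabled()
        ctx.save_for_backward(x)
        with torch.no_grad():
            return module(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        x = x.detach().requires_grad_(True)
        # Re-enter the same autocast state before recomputing, so fp32
        # weights are downcast again and input/weight dtypes match.
        with torch.enable_grad(), torch.autocast("cuda", enabled=ctx.had_autocast):
            out = ctx.module(x)
        out.backward(grad_out)
        return x.grad, None
```

torch.utils.checkpoint additionally restores the CPU autocast state and the RNG state; a real PR would presumably need the same care, plus a unit test running RevLib's forward and backward under torch.cuda.amp.autocast.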

ClashLuke avatar Aug 01 '23 11:08 ClashLuke

Thanks a lot for RevLib. I hope to contribute myself once I have enough time.

JAYatBUAA avatar Aug 01 '23 11:08 JAYatBUAA