
torch.einsum does not cast tensors when using apex.amp

Open aRI0U opened this issue 4 years ago • 3 comments

Hi, I am using apex.amp to reduce the memory footprint of my model. I followed the instructions given here, but I got an error when using torch.einsum.

Here is a minimal snippet:

import apex.amp as amp
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc = nn.Linear(7, 3)
        self.p = nn.Parameter(torch.randn(3, 4))

    def forward(self, x):
        p = self.p.repeat(x.size(0), 1, 1)
        print('p.dtype:', p.dtype)

        print('x.dtype:', x.dtype)
        x = self.fc(x)
        print('x.dtype:', x.dtype)

        # torch.bmm is patched by amp under O1, so both operands are cast to fp16
        y = torch.bmm(x, p)
        print('y.dtype:', y.dtype)

        # torch.einsum is not patched by amp: x is fp16 here but p is still fp32
        z = torch.einsum('bik,bkj->bij', (x,p))
        print('z.dtype:', z.dtype)

        assert torch.equal(y, z)
        return z

model = Model()
model = model.cuda()
optimizer = torch.optim.Adam(model.parameters())

model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

x = torch.randn(2, 5, 7).cuda()

y = model(x)

The output is:

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Warning:  multi_tensor_applier fused unscale kernel is unavailable, possibly because apex was installed without --cuda_ext --cpp_ext. Using Python fallback.  Original ImportError was: ModuleNotFoundError("No module named 'amp_C'")
p.dtype: torch.float32
x.dtype: torch.float32
x.dtype: torch.float16
y.dtype: torch.float16
Traceback (most recent call last):
  File "snippet.py", line 36, in <module>
    y = model(x)
  File "/home/alain/miniconda3/envs/nes/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "snippet.py", line 22, in forward
    z = torch.einsum('bik,bkj->bij', (x,p))
  File "/home/alain/miniconda3/envs/nes/lib/python3.8/site-packages/torch/functional.py", line 292, in einsum
    return _VF.einsum(equation, operands)
RuntimeError: Expected object of scalar type Half but got scalar type Float for argument #2 'mat2' in call to _th_bmm

Changing the opt level to O2 or O3 avoids this error, since model.p is then cast earlier, but that is just sidestepping the issue. Could you check whether you get the same problem when you run my code, and tell me if you have a solution?
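In the meantime, manually matching the dtypes works as a stopgap (just a sketch, assuming a runtime cast of the parameter is acceptable):

z = torch.einsum('bik,bkj->bij', (x, p.to(x.dtype)))  # cast p to x's dtype (fp16 under O1)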

Thanks a lot

aRI0U avatar Jun 26 '20 13:06 aRI0U

Hello,

I've run into the same problem, and I don't think torch.einsum autocasts its input tensors. However, I've found that using torch.cuda.amp solves the issue. Below is the code I used for testing: https://gist.github.com/hyukyu/a03816f120335f1f8ffb8b8ccca17eb7
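In short, the idea is to wrap the forward pass in torch.cuda.amp.autocast. A minimal sketch along those lines (reusing the Model class from the snippet above, not the gist verbatim; the loss is illustrative only):

import torch
from torch.cuda.amp import autocast, GradScaler

model = Model().cuda()
optimizer = torch.optim.Adam(model.parameters())
scaler = GradScaler()  # replaces apex's dynamic loss scaling

x = torch.randn(2, 5, 7).cuda()

with autocast():  # ops run in fp16 where safe; einsum inputs are cast consistently
    z = model(x)
    loss = z.float().sum()  # toy loss, for illustration only

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()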

ghost avatar Aug 27 '20 06:08 ghost

Have you solved the problem? @aRI0U

HuiqinWu avatar Mar 21 '22 13:03 HuiqinWu

Hi, I did not solve it, but I found two ways to work around it. The first is to avoid torch.einsum altogether, since the operation can usually be replaced by axis permutations and tensor multiplications. The other is to use PyTorch's native AMP implementation (torch.cuda.amp), which has been available since PyTorch 1.6.0 if I remember correctly. Neither is a really satisfying solution, but at least it works.
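For the first workaround, this particular einsum is just a batched matrix product, so it can be rewritten directly (a sketch for this specific equation; other patterns may need .permute(...) first to line up the axes):

z = torch.bmm(x, p)  # equivalent to torch.einsum('bik,bkj->bij', (x, p))
# or simply: z = x @ p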

aRI0U avatar Mar 21 '22 13:03 aRI0U