apex
RuntimeError: expected scalar type Float but found Half
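import torch.nn as nn

# DCN is the deformable convolution from the compiled C++/CUDA extension;
# its import path depends on how the extension was built (not shown in the post).
BN_MOMENTUM = 0.1  # assumed value; the post does not show it

class DeformConv(nn.Module):
    def __init__(self, chi, cho):
        super(DeformConv, self).__init__()
        # BatchNorm + ReLU applied after the deformable convolution
        self.actf = nn.Sequential(
            nn.BatchNorm2d(cho, momentum=BN_MOMENTUM),
            nn.ReLU(inplace=True)
        )
        # 3x3 deformable convolution from the compiled extension
        self.conv = DCN(chi, cho, kernel_size=(3, 3), stride=1, padding=1, dilation=1, deformable_groups=1)

    def forward(self, x):
        x = self.conv(x)  # the RuntimeError is raised here under amp O1
        x = self.actf(x)
        return x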
My model contains a DCN module compiled from C++. When I call amp.initialize(model, optimizer, opt_level="O1"), a RuntimeError (expected scalar type Float but found Half) is raised at x = self.conv(x). I tried converting the type with x = self.conv(x.float()), but it did not help.
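Under O1, apex runs ordinary ops such as convolutions in FP16, but it does not cast the inputs of a compiled extension. One plausible reason x.float() did not help, assuming your DCN is the common DCNv2-style module that computes its offsets and mask with a regular nn.Conv2d, is that this internal conv still produces a Half tensor, so the C++ op receives a mix of Float and Half inputs. For custom ops, apex documents amp.register_float_function(module, function_name), called before amp.initialize; the module and function name are whatever your extension exposes, which the post does not show. If you can move to PyTorch's native AMP instead (torch.cuda.amp, available since 1.6), a minimal sketch of the same idea is to run the DCN in float32 with autocast disabled. Float32Wrapper is a hypothetical helper, not part of PyTorch or apex, and the check at the end uses nn.Linear on CPU as a stand-in for DCN. torch.autocast with device_type needs PyTorch 1.10 or newer; on 1.6 to 1.9, torch.cuda.amp.autocast(enabled=False) plays the same role.

```python
import torch
import torch.nn as nn


class Float32Wrapper(nn.Module):
    """Hypothetical helper: run the wrapped module in float32 inside a native autocast region.

    Only affects torch.autocast / torch.cuda.amp.autocast. apex O1 patches torch
    functions instead, so disabling autocast here does nothing under apex.
    """

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module.float()  # keep the wrapped weights in float32

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Turn autocast off so ops inside the wrapped module (e.g. an offset conv)
        # stay in float32, and cast the incoming activation, which may be float16.
        with torch.autocast(device_type=x.device.type, enabled=False):
            return self.module(x.float())


# Usage in DeformConv.__init__ (DCN is the compiled extension from the post):
#     self.conv = Float32Wrapper(DCN(chi, cho, kernel_size=(3, 3), stride=1,
#                                    padding=1, dilation=1, deformable_groups=1))

# Quick CPU check, with nn.Linear standing in for the DCN module
layer = Float32Wrapper(nn.Linear(4, 3))
x = torch.randn(2, 4)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = layer(x)
```

The BatchNorm that follows then receives a float32 tensor, which autocast handles without complaint.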
I have the same problem. Have you solved it?
Same problem in PyTorch 1.6.
Check the inputs and convert every float16 parameter to float32, e.g. with x.float().
Could you please tell me how to do it efficiently?
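A minimal sketch of doing the check and the conversion in bulk; find_half_tensors and to_float32 are hypothetical helper names. model.float() casts every floating-point parameter and buffer (BatchNorm running statistics included), which a loop over model.parameters() alone would miss.

```python
import torch
import torch.nn as nn


def find_half_tensors(model: nn.Module, *inputs):
    """Return the names of float16 parameters, buffers and tensor inputs."""
    names = [name for name, p in model.named_parameters() if p.dtype == torch.float16]
    names += [name for name, b in model.named_buffers() if b.dtype == torch.float16]
    names += [f"input[{i}]" for i, t in enumerate(inputs)
              if torch.is_tensor(t) and t.dtype == torch.float16]
    return names


def to_float32(model: nn.Module, *inputs):
    """Cast the model (parameters and buffers) and floating-point tensor inputs to float32."""
    model = model.float()  # in place; returns the same module
    inputs = tuple(t.float() if torch.is_tensor(t) and t.is_floating_point() else t
                   for t in inputs)
    return model, inputs
```

Calling find_half_tensors right before the failing line tells you which tensor is still Half; to_float32 then fixes all of them at once.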
I get the same error, but strangely, when I replace nn.MSELoss with nn.BCEWithLogitsLoss the code runs fine; the error appears only with MSELoss. How can I solve this?
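If the mismatch is between the model output and the target (an assumption, since the loss code is not shown), one common workaround is to compute the loss on float32 copies of both. mse_loss_fp32 is a hypothetical helper name.

```python
import torch
import torch.nn as nn

criterion = nn.MSELoss()


def mse_loss_fp32(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Compute MSE in float32 even if prediction and target arrive with different dtypes."""
    return criterion(prediction.float(), target.float())


prediction = torch.randn(4, 1).half()  # e.g. a float16 model output
target = torch.randn(4, 1)             # float32 labels
loss = mse_loss_fp32(prediction, target)
```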
This error occurs when the tensors in an operation, for example the two matrices in a matrix multiplication, do not have the same dtype.
Half means dtype = torch.float16, while Float means dtype = torch.float32.
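A small sketch that reproduces the mismatch on CPU and shows that casting one operand fixes it; the tensor names are illustrative only.

```python
import torch

a = torch.randn(2, 3).half()  # float16, reported as "Half"
b = torch.randn(3, 4)         # float32 (the default), reported as "Float"

mismatch_error = None
try:
    torch.mm(a, b)  # mm does not promote mixed dtypes, so this raises
except RuntimeError as err:
    mismatch_error = str(err)  # exact wording depends on the PyTorch version

# Casting one operand so both share a dtype fixes it
c = torch.mm(a.float(), b)
```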
To resolve the error, cast your model weights to float32:
import torch

# Cast any float16 (Half) parameters of `model` back to float32 (Float), in place
for param in model.parameters():
    # Check if parameter dtype is Half (float16)
    if param.dtype == torch.float16:
        param.data = param.data.to(torch.float32)
I know I'm late to this thread, but could you tell me where to put this parameter-casting loop? Thanks.
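A minimal sketch of where the loop typically goes: once, right after the model is constructed (and moved to the GPU, if any), before the optimizer is created and before the first forward pass. The toy nn.Sequential model and the cast_half_params_to_float name are hypothetical stand-ins for your own network and code.

```python
import torch
import torch.nn as nn


def cast_half_params_to_float(model: nn.Module) -> nn.Module:
    """The parameter-casting loop from the answer, wrapped in a function."""
    for param in model.parameters():
        if param.dtype == torch.float16:
            param.data = param.data.to(torch.float32)
    return model


# 1. Build the model (a toy stand-in for your own network)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
model.half()  # simulate weights that ended up in float16

# 2. Run the cast once, after the model is built and before the first forward pass
model = cast_half_params_to_float(model)

# 3. Float32 inputs now match the float32 weights
x = torch.randn(3, 4)
out = model(x)
```

Note that model.parameters() does not include buffers such as BatchNorm's running_mean and running_var; if those are float16 too, model.float() casts parameters and buffers together.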