Retinanet-Pytorch icon indicating copy to clipboard operation
Retinanet-Pytorch copied to clipboard

为什么会出现这个错误RuntimeError: Expected object of scalar type float but got scalar type double for argument 'other'

Open jfdyft opened this issue 3 years ago • 14 comments

jfdyft avatar Oct 02 '20 15:10 jfdyft

我也出现了

lmcltj avatar Oct 21 '20 08:10 lmcltj

因为,数据是双精度的啊,而需要的是单精度的。 如果你需要模型是双精度的, 用 model = model.double() 改成双精度就行。 或者 把你数据 data=data.float() 改成单精度的

yatengLG avatar Oct 21 '20 09:10 yatengLG

请问这个错误具体应该怎么改呢?我尝试改了几个地方的,但是还是有错误

taochx avatar Oct 25 '20 12:10 taochx

@taochongxin 模型权重和数据均是有精度的。 一般来讲,自定义的模型,如果没有特别指定,一般都是单精度的,也就是float, 主要问题应该是出在你的数据上,你把你的数据,的dtype打印出来,看看是什么。 然后把模型和数据的 dtype改成一样的就可以了。 这个问题也一般会出现在损失函数中,反正涉及到权重的地方都会有精度问题。你统一了就好了

yatengLG avatar Oct 26 '20 00:10 yatengLG

@yatengLG 多谢回复,我使用的数据集也是voc2007,我再尝试改一下,谢谢

taochx avatar Oct 26 '20 01:10 taochx

我都改成双精度了,可以跑通

lmcltj avatar Oct 27 '20 03:10 lmcltj

我也遇到这个问题了

dfy888 avatar Dec 04 '20 02:12 dfy888

@dfy888 精度问题,模型精度和数据精度不一致。

yatengLG avatar Dec 04 '20 02:12 yatengLG

型精度和数据精度不一致

AttributeError: 'DataLoader' object has no attribute 'dtype' 到底改哪里啊? 模型 数据都没有 dtype属性? 感谢作者大大的回复

dfy888 avatar Dec 04 '20 02:12 dfy888

@dfy888

for param in net.parameters(): print(param.dtype) 查看模型参数的 数值精度

然后你数据也是同样的,只要是tensor肯定有dtype的

yatengLG avatar Dec 04 '20 02:12 yatengLG

@dfy888 dataloader 是个 数据加载器的类,不是你的输入数据。 你训练的时候不是,有个 for image, label in dataloader: 么 , 这里拿出来的才是数据啊

yatengLG avatar Dec 04 '20 02:12 yatengLG

关于这个报错,统一回复

原因是

  • 模型与数据类型不一致,
  • " Expected object of scalar type float but got scalar type double for argument 'other' ",应该输入的是float类型,但是输入的参数是double类型的。

具体查看方法是

  • for param in net.parameters(): print(param.dtype) 打印你模型权重的dtype类型(模型是没有dtype属性的,模型的数值类型具体来讲是模型权重参数的数值类型)。`
  • for image, label in dataloader : print(image.dtype, label,dtype)
    查看输入数据以及标签的数值类型。

解决方法:

  1. 将模型调整为双精度 net.double()
  2. 调整为单精度 net.float()
  3. 数据同操作

yatengLG avatar Dec 04 '20 03:12 yatengLG

@dfy888 dataloader 是个 数据加载器的类,不是你的输入数据。 你训练的时候不是,有个 for image, label in dataloader: 么 , 这里拿出来的才是数据啊

我到不了那里 刚刚定位了一下是这里的问题: print("传进来的类型torch:gt_boxes={};gt_labels={}".format(gt_boxes.dtype, gt_labels.dtype)) boxes, labels = assign_priors(gt_boxes, gt_labels, self.corner_form_priors, self.iou_threshold)

dfy888 avatar Dec 04 '20 03:12 dfy888

终于跑起来了 我自己的数据集 问题在这:

print("这里的错啊:", gt_boxes.dtype, gt_boxes.size())

# print("这里的错啊:", corner_form_priors.dtype, corner_form_priors.size())
# 所以我经常遇到'float64' * 'float32' 报错
corner_form_priors = corner_form_priors.float()  # 'float32' 我加的这句

ious = iou_of(gt_boxes.unsqueeze(0), corner_form_priors.unsqueeze(1))

希望可以帮到跟我一样的朋友 再次感谢作者的注释

dfy888 avatar Dec 04 '20 03:12 dfy888