PseCo icon indicating copy to clipboard operation
PseCo copied to clipboard

cls_head 计算特征时维度不匹配?

Open lichen14 opened this issue 1 year ago • 3 comments

https://github.com/Hzzone/PseCo/blob/0192ee521cd3786ed5f2d89231fd74567ef87d9c/fsc147/4_1_train_roi_head.py#L277

我在跑这行代码的时候总是会遇到embeddings和prompts特征维度不匹配的问题,导致 https://github.com/Hzzone/PseCo/blob/0192ee521cd3786ed5f2d89231fd74567ef87d9c/models.py#L25 这里无法直接矩阵乘

比如:

  • 计算 https://github.com/Hzzone/PseCo/blob/0192ee521cd3786ed5f2d89231fd74567ef87d9c/fsc147/4_1_train_roi_head.py#L277 时,embeddings 和 prompts.unsqueeze(1) 的shape 分别是 torch.Size([8192, 3, 512]) torch.Size([8192, 1, 3, 512]) ,这样直接计算会显存爆炸,需要384G ** 我试过把prompts.unsqueeze(1) 去掉unsqueeze(1) ,这样不报显存OOM的bug了,但是下面的代码会报错:
  • 计算 https://github.com/Hzzone/PseCo/blob/0192ee521cd3786ed5f2d89231fd74567ef87d9c/fsc147/4_1_train_roi_head.py#L283 时,embeddings 和 prompts的shape是torch.Size([1536, 3, 512]) torch.Size([1536, 512]),这时候就会有 RuntimeError: The size of tensor a (3) must match the size of tensor b (1536) at non-singleton dimension 1 。此时如果加上prompts.unsqueeze(1)就不会报错了

所以,是否是代码历史版本的问题?有无最新版的可行代码呀?

lichen14 avatar Jul 10 '24 08:07 lichen14

https://github.com/Hzzone/PseCo/blob/0192ee521cd3786ed5f2d89231fd74567ef87d9c/fsc147/4_1_train_roi_head.py#L277

我在跑这行代码的时候总是会遇到embeddings和prompts特征维度不匹配的问题,导致

https://github.com/Hzzone/PseCo/blob/0192ee521cd3786ed5f2d89231fd74567ef87d9c/models.py#L25

这里无法直接矩阵乘 比如:

  • 计算 https://github.com/Hzzone/PseCo/blob/0192ee521cd3786ed5f2d89231fd74567ef87d9c/fsc147/4_1_train_roi_head.py#L277

    时,embeddings 和 prompts.unsqueeze(1) 的shape 分别是 torch.Size([8192, 3, 512]) torch.Size([8192, 1, 3, 512]) ,这样直接计算会显存爆炸,需要384G ** 我试过把prompts.unsqueeze(1) 去掉unsqueeze(1) ,这样不报显存OOM的bug了,但是下面的代码会报错:

  • 计算 https://github.com/Hzzone/PseCo/blob/0192ee521cd3786ed5f2d89231fd74567ef87d9c/fsc147/4_1_train_roi_head.py#L283

    时,embeddings 和 prompts的shape是torch.Size([1536, 3, 512]) torch.Size([1536, 512]),这时候就会有 RuntimeError: The size of tensor a (3) must match the size of tensor b (1536) at non-singleton dimension 1 。此时如果加上prompts.unsqueeze(1)就不会报错了

所以,是否是代码历史版本的问题?有无最新版的可行代码呀?

你好,请问问题解决了吗?

gs-max avatar Nov 11 '24 07:11 gs-max

这个问题我测试过了,用我上传的处理好的数据没有问题,应该是做数据这一部分我在整理的时候shape不一致导致的。 解决的方法也很简单:https://github.com/Hzzone/PseCo/blob/aab6cdb4a5112cce563df5536ead28d9bb43aae1/models.py#L27

加上一两行就好了。

Hzzone avatar Nov 12 '24 17:11 Hzzone

这个问题我测试过了,用我上传的处理好的数据没有问题,应该是做数据这一部分我在整理的时候shape不一致导致的。 解决的方法也很简单:

https://github.com/Hzzone/PseCo/blob/aab6cdb4a5112cce563df5536ead28d9bb43aae1/models.py#L27

加上一两行就好了。

您好您好,大佬,我没有找到有效的pth文件啊,期待回复一下,好像是谷歌存储库过期了?

ld-xy avatar Dec 25 '24 11:12 ld-xy