cls_head 计算特征时维度不匹配?
https://github.com/Hzzone/PseCo/blob/0192ee521cd3786ed5f2d89231fd74567ef87d9c/fsc147/4_1_train_roi_head.py#L277
我在跑这行代码的时候总是会遇到embeddings和prompts特征维度不匹配的问题,导致 https://github.com/Hzzone/PseCo/blob/0192ee521cd3786ed5f2d89231fd74567ef87d9c/models.py#L25 这里无法直接矩阵乘
比如:
- 计算 https://github.com/Hzzone/PseCo/blob/0192ee521cd3786ed5f2d89231fd74567ef87d9c/fsc147/4_1_train_roi_head.py#L277 时,embeddings 和 prompts.unsqueeze(1) 的shape 分别是 torch.Size([8192, 3, 512]) torch.Size([8192, 1, 3, 512]) ,这样直接计算会显存爆炸,需要384G ** 我试过把prompts.unsqueeze(1) 去掉unsqueeze(1) ,这样不报显存OOM的bug了,但是下面的代码会报错:
- 计算
https://github.com/Hzzone/PseCo/blob/0192ee521cd3786ed5f2d89231fd74567ef87d9c/fsc147/4_1_train_roi_head.py#L283
时,embeddings 和 prompts的shape是torch.Size([1536, 3, 512]) torch.Size([1536, 512]),这时候就会有
RuntimeError: The size of tensor a (3) must match the size of tensor b (1536) at non-singleton dimension 1。此时如果加上prompts.unsqueeze(1)就不会报错了
所以,是否是代码历史版本的问题?有无最新版的可行代码呀?
https://github.com/Hzzone/PseCo/blob/0192ee521cd3786ed5f2d89231fd74567ef87d9c/fsc147/4_1_train_roi_head.py#L277
我在跑这行代码的时候总是会遇到embeddings和prompts特征维度不匹配的问题,导致
https://github.com/Hzzone/PseCo/blob/0192ee521cd3786ed5f2d89231fd74567ef87d9c/models.py#L25
这里无法直接矩阵乘 比如:
计算 https://github.com/Hzzone/PseCo/blob/0192ee521cd3786ed5f2d89231fd74567ef87d9c/fsc147/4_1_train_roi_head.py#L277
时,embeddings 和 prompts.unsqueeze(1) 的shape 分别是 torch.Size([8192, 3, 512]) torch.Size([8192, 1, 3, 512]) ,这样直接计算会显存爆炸,需要384G ** 我试过把prompts.unsqueeze(1) 去掉unsqueeze(1) ,这样不报显存OOM的bug了,但是下面的代码会报错:
计算 https://github.com/Hzzone/PseCo/blob/0192ee521cd3786ed5f2d89231fd74567ef87d9c/fsc147/4_1_train_roi_head.py#L283
时,embeddings 和 prompts的shape是torch.Size([1536, 3, 512]) torch.Size([1536, 512]),这时候就会有
RuntimeError: The size of tensor a (3) must match the size of tensor b (1536) at non-singleton dimension 1。此时如果加上prompts.unsqueeze(1)就不会报错了所以,是否是代码历史版本的问题?有无最新版的可行代码呀?
你好,请问问题解决了吗?
这个问题我测试过了,用我上传的处理好的数据没有问题,应该是做数据这一部分我在整理的时候shape不一致导致的。 解决的方法也很简单:https://github.com/Hzzone/PseCo/blob/aab6cdb4a5112cce563df5536ead28d9bb43aae1/models.py#L27
加上一两行就好了。
这个问题我测试过了,用我上传的处理好的数据没有问题,应该是做数据这一部分我在整理的时候shape不一致导致的。 解决的方法也很简单:
https://github.com/Hzzone/PseCo/blob/aab6cdb4a5112cce563df5536ead28d9bb43aae1/models.py#L27
加上一两行就好了。
您好您好,大佬,我没有找到有效的pth文件啊,期待回复一下,好像是谷歌存储库过期了?