fpn.pytorch
Inconsistent anchor reference (may be the cause of the lower accuracy than faster-rcnn)
Hi,
Please notice you have an inconsistent reference to the order of the anchors (lines 86 and 79):
https://github.com/jwyang/fpn.pytorch/blob/23bd1d2fa09fbb9453f11625d758a61b9d600942/lib/model/rpn/rpn_fpn.py#L86
https://github.com/jwyang/fpn.pytorch/blob/23bd1d2fa09fbb9453f11625d758a61b9d600942/lib/model/rpn/rpn_fpn.py#L79
Let's say k is the number of anchors. Then in line 79 you apply softmax after a reshape that makes anchors 0..k-1 associated with proposal = False and anchors k..2k-1 associated with proposal = True.
In line 86, however, your anchors are arranged differently: i mod 2 == 0 is associated with proposal = False and i mod 2 == 1 with proposal = True.
I propose reshaping the score as follows to be consistent (you will also need to change a few other things in proposalLayer for it to work):

`rpn_cls_scores.append(rpn_cls_score_reshape.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, 2))`
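To make the mismatch concrete, here is a minimal sketch (shapes invented; not the repo's code) assuming the cls head lays its 2*A channels out as [bg_0..bg_{A-1}, fg_0..fg_{A-1}], which is the layout the line-79 reshape implies:

```python
import torch

# Minimal sketch (shapes invented; not the repo's code) of the mismatch.
# Assume A = 3 anchors, so the cls head outputs 2*A channels, laid out as
# [bg_0, bg_1, bg_2, fg_0, fg_1, fg_2].
B, A, H, W = 1, 3, 2, 2
score = torch.arange(B * 2 * A * H * W, dtype=torch.float32).view(B, 2 * A, H, W)

# Line-79 style: view as (B, 2, A*H, W); softmax over dim=1 pairs channel c
# with channel c + A (bg block vs fg block).
score_reshape = score.view(B, 2, A * H, W)

# Line-86 style flattening of the *raw* score pairs adjacent channels
# (0,1), (2,3), ... -- i.e. it silently assumes an interleaved layout.
interleaved = score.permute(0, 2, 3, 1).contiguous().view(B, -1, 2)

# Proposed consistent version: flatten the line-79 reshape instead.
consistent = score_reshape.permute(0, 2, 3, 1).contiguous().view(B, -1, 2)

print(torch.equal(interleaved, consistent))  # False: the pairings differ
```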
Hi @doronpor, great thanks for pointing this out! I also suspected this part was the reason for the worse performance, but I didn't have time to check my implementation. Have you tried modifying it and training the model?
Hi Jwyang.
I have similar code, with these changes, that has been trained and is working. I can try to help with the modifications to your code.
@doronpor What's your new mAP on Pascal VOC 07? Also, you mention above that "you will also need to change a few other things in proposalLayer..." I don't see this. To me, it looks like you only need to change `rpn_cls_score` --> `rpn_cls_score_reshape` on line 86.
Hi @doronpor, same question, :) -- what is your mAP on VOC07? I will add you as a collaborator, and it would be great if you could help with that, :)
I think you also have to change `rpn_cls_prob` on line 87 to `rpn_cls_prob_reshape`. Then, the output tensor after the `permute` and `view` will correspond to softmax probs in the form:

```
[[bg, fg],
 [bg, fg]]
```

Then you won't have to change anything in the proposal layer, since you select the fg scores `score[:,:,1]` here.
@Feynman27 Yes, you are right, I did that and I am evaluating the model.
@jwyang @feynman27 I have not trained on VOC07 so I can't tell you.
An additional issue may be that `rpn_bbox_pred`, after the view operation, enumerates the anchors by (anchor_ratio x width x height), while because of the change we made, `rpn_cls_score` has a different enumeration, i.e. (width x height x anchor_ratio).
The problem I stated before is that proposal_layer assumes they have the same order, so you may still have an issue.
@doronpor The `rpn_bbox_pred` dims are permuted here:

`rpn_bbox_preds.append(rpn_bbox_pred.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, 4))`

This makes the order (batch_size, H, W, A).
@Feynman27
The issue is not a mismatch in dimensions, but rather the order of anchor, width, and height after all the reshape operations.

For `rpn_cls_score_reshape`, before the reorder we had dims (batch_size, 2, A*H, W). We apply the permute and view operations to this tensor:

`rpn_cls_score_reshape.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, 2)`

After the reshape, if you look at the enumeration order along dim=1, i.e. `rpn_cls_scores[0, :, 1]`, you have (width x height x anchor_ratio).

For `rpn_bbox_preds` we don't reshape the tensor first, which means it has dims (batch_size, A*4, H, W). We apply the permute and view operations to this tensor:

`rpn_bbox_pred.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, 4)`

After the reshape, if you look at the enumeration order along dim=1, i.e. `rpn_bbox_preds[0, :, 0]`, you have (anchor_ratio x width x height).

This is because in one place we use the reshaped tensor (`rpn_cls_score_reshape`) and in the other (`rpn_bbox_preds`) we don't.
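The two enumeration orders can be checked with a toy example (A anchors on an H x W grid; all shapes invented for illustration): tag every slot with its anchor id, flatten each tensor the way the code does, and compare which anchor each row refers to.

```python
import torch

# Toy check of the two enumeration orders (shapes invented for illustration).
A, H, W = 2, 2, 3

# Tag every (anchor, h, w) slot with its anchor id.
anchor_ids = torch.arange(A).view(A, 1, 1).expand(A, H, W)

# cls path: one bg/fg slice of the reshaped tensor is laid out as (1, 1, A*H, W);
# permute(0, 2, 3, 1) + flatten gives row index a*H*W + h*W + w.
cls_rows = anchor_ids.reshape(1, 1, A * H, W).permute(0, 2, 3, 1).reshape(-1)

# bbox path: the raw tensor is laid out as (1, A, H, W) per coordinate;
# permute(0, 2, 3, 1) + flatten gives row index (h*W + w)*A + a.
bbox_rows = anchor_ids.reshape(1, A, H, W).permute(0, 2, 3, 1).reshape(-1)

print(cls_rows.tolist())   # anchor slowest: [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
print(bbox_rows.tolist())  # anchor fastest: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
```

Rows of the two flattened tensors refer to different anchors, which is exactly the mismatch the proposal layer cannot tolerate.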
So maybe we could match the order by doing something like:

```python
rpn_bbox_pred = self.RPN_bbox_pred(rpn_conv1)
# reshape to match order of anchors (note: integer division for Python 3)
rpn_bbox_pred_reshape = rpn_bbox_pred.view(rpn_bbox_pred.size(0),
                                           rpn_bbox_pred.size(1) // 4,
                                           4,
                                           rpn_bbox_pred.size(2),
                                           rpn_bbox_pred.size(3))
rpn_bbox_pred_reshape = rpn_bbox_pred_reshape.permute(0, 1, 3, 4, 2).contiguous().view(batch_size, -1, 4)
```

This would have an enumeration order of (Width x Height x Anchor).
I'm checking now to see if this is the same order by which the anchors are generated.
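As a quick sanity check of the reshape above (a sketch with invented shapes, not the repo's code), tagging each (anchor, h, w) slot with its anchor id shows the proposed path does enumerate the anchor slowest, matching the cls order:

```python
import torch

# Sanity check (invented shapes): tag each (anchor, h, w) slot with its
# anchor id and flatten it along the proposed path.
B, A, H, W = 1, 2, 2, 3
anchor_ids = torch.arange(A).view(1, A, 1, 1, 1).expand(B, A, 4, H, W)

# Proposed path: (B, A, 4, H, W) -> permute(0, 1, 3, 4, 2) -> view(B, -1, 4)
flat = anchor_ids.permute(0, 1, 3, 4, 2).reshape(B, -1, 4)

# Anchor is slowest (row index a*H*W + h*W + w), matching the cls order:
print(flat[0, :, 0].tolist())  # [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
```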
So I don't believe that `rpn_bbox_pred` needs to be reshaped. If you print out line 138 of the anchor generation code, the anchors are enumerated according to (Anchor x Width x Height), thus `rpn_bbox_pred` should also be enumerated in this order (Anchor x Width x Height). So I think you should only use the reshaped `rpn_cls_score` and `rpn_cls_prob`, but `rpn_bbox_pred` is in the correct order and aligns with the anchors.
@Feynman27 You are correct: the anchors in proposal_layer and `rpn_bbox_pred` are enumerated as you say, by (Anchor x Width x Height), but `rpn_cls_score` is enumerated by (Width x Height x Anchor), because we now take the reshaped score.
I changed both `rpn_bbox_pred`, as you suggested, using `self.reshape(rpn_bbox_pred, 4)`, and the anchor creation inside the proposal layer, so everything matches (Width x Height x Anchor).
Okay, I see here that the indices of the scores/probs must exactly match the indices of the bbox_deltas. I've changed the anchor generation minimally so that the enumeration order is (Width x Height x Anchor):

```python
# Reshape to get a list of (x, y) and a list of (w, h)
# Enumerate by (Width x Height x Anchor)
box_centers = np.stack([box_centers_x.transpose(1, 0).reshape(-1),
                        box_centers_y.transpose(1, 0).reshape(-1)], axis=1)
box_sizes = np.stack([box_widths.transpose(1, 0).reshape(-1),
                      box_heights.transpose(1, 0).reshape(-1)], axis=1)
# Convert to corner coordinates (x1, y1, x2, y2)
boxes = np.concatenate([box_centers - 0.5 * box_sizes,
                        box_centers + 0.5 * box_sizes], axis=1)
```
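As a small sketch of what the `transpose(1, 0)` changes, assuming (as in matterport-style anchor code) the pre-stacked arrays have shape (grid_positions, anchors):

```python
import numpy as np

# Sketch (assumed shapes, not the repo's code): arrays of shape
# (grid_positions, anchors), e.g. a 3x2 grid with 2 anchors per cell.
positions, A = 6, 2
anchor_id = np.tile(np.arange(A), (positions, 1))  # shape (positions, A)

# Without the transpose, flattening enumerates the anchor fastest;
# with it, the anchor becomes slowest (i.e. Width x Height x Anchor).
print(anchor_id.reshape(-1).tolist())                  # [0, 1, 0, 1, ...]
print(anchor_id.transpose(1, 0).reshape(-1).tolist())  # [0, 0, ..., 1, 1, ...]
```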
@jwyang @doronpor @Feynman27 Hi, have you guys fixed the bug? Thanks.
@jwyang @doronpor Hi, if any of you has already solved this problem in your own code, could you please kindly update the code on this GitHub repo? I've had some trouble solving it on my own. Thank you very much!
@Feynman27 @doronpor Hi, and I have an additional question, if I may. You seem to assume that rpn_cls_score and rpn_bbox_pred have an exactly arranged order of [batch_size, anchor_ratio(A), width(W), height(H)]. So after reshaping rpn_bbox_pred to fix the bug, you also talked about changing the anchor generation layer to keep the enumeration order the same everywhere. However, rpn_cls_score and rpn_bbox_pred come from two different conv layers, so how could the outputs of conv layers have a pre-arranged order? Couldn't it be any order that we leave the network itself to learn? I mean, before they are fed into self.RPN_proposal, they are viewed as [batch_size, -1, 2] and [batch_size, -1, 4]; at dim 1 = -1 they are all just multiplied together anyway. I followed your discussion (including changing the anchor generation layer) and got only 74.08 mAP on VOC2007. Maybe I did something wrong, so I really need your help.
@qq184861643 I also didn't observe much of a change in the mAP on the dataset that I trained on. This was surprising. So it seems this bug doesn't make much of a difference. This is why I haven't made a pull request.
I also trained on my dataset using Detectron and only observed a ~1.5% mAP@0.5 increase. I also used a larger input scale (800) when training with Detectron.
@jwyang, I met the following issue, can you offer me a solution? Thanks a lot.

```
Traceback (most recent call last):
  File "/media/csu/新加卷/AI Competition/Competition/2018广东工业智造大数据创新大赛/intermediary_contest/code/fpn.pytorch-master/trainval_net.py", line 313, in <module>
    roi_labels = FPN(im_data, im_info, gt_boxes, num_boxes)
  File "/home/csu/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/media/csu/新加卷/AI Competition/Competition/2018广东工业智造大数据创新大赛/intermediary_contest/code/fpn.pytorch-master/lib/model/fpn/fpn.py", line 251, in forward
    bbox_pred_select = torch.gather(bbox_pred_view, 1, rois_label.view(rois_label.size(0), 1, 1).expand(rois_label.size(0), 1, 4))
RuntimeError: invalid argument 2: Input tensor must have same size as output tensor apart from the specified dimension at /pytorch/aten/src/THC/generic/THCTensorScatterGather.cu:29
```

@Feynman27, I met the same issue as in the traceback above; can you offer me a solution? Thanks a lot.
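For anyone hitting this: the error comes from `torch.gather`'s shape rule, which requires the index tensor to match the input tensor in every dimension except the gather dim. A minimal sketch (shapes invented for illustration):

```python
import torch

# Minimal sketch (shapes invented) of the torch.gather rule behind this error.
bbox_pred_view = torch.randn(6, 3, 4)          # (num_rois, num_classes, 4)
rois_label = torch.tensor([0, 2, 1, 1, 0, 2])  # one class label per roi

# gather along dim=1 requires the index tensor to match bbox_pred_view's
# shape in every dim except dim 1:
idx = rois_label.view(-1, 1, 1).expand(rois_label.size(0), 1, 4)
out = torch.gather(bbox_pred_view, 1, idx)     # OK: shape (6, 1, 4)
print(out.shape)
```

The RuntimeError fires when `rois_label.size(0)` does not equal `bbox_pred_view.size(0)`, e.g. when one of the two tensors has been flattened over the batch and the other has not.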
@1csu I am facing the same issue. Did you figure out how to solve it? @Feynman27 @jwyang
I am facing the same issue. Did you figure out how to solve it? @1csu @Karthik-Suresh93