Reshape misalignment

prismformore opened this issue 4 years ago · 7 comments

https://github.com/xieenze/Trans2Seg/blob/a1591849ced066d32171a25cec01a25fcacccf2b/segmentron/modules/transformer.py#L92

Maybe here you want to swap dim 1 (num_heads) and dim 2 (num_class) before reshaping? Or is it my misunderstanding?

Thank you very much.

prismformore avatar Feb 09 '21 11:02 prismformore

I don't think this code has any problem; at least it trains and tests well. I have no idea why we would need to swap dim 1 and dim 2.

xieenze avatar Feb 09 '21 11:02 xieenze

x has shape [B, heads, n_class, embed_dim/nheads], and I thought you might want to combine the "heads" and "embed_dim/nheads" dims into "embed_dim". In that case, we would need to permute x to [B, n_class, heads, embed_dim/nheads] first and then reshape it into [B, n_class, embed_dim].
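For concreteness, here is a minimal PyTorch sketch of the combine-heads pattern I have in mind (sizes and variable names are illustrative, not taken from the repository):

```python
import torch

B, heads, n_class, head_dim = 2, 4, 12, 16  # illustrative sizes
embed_dim = heads * head_dim
x = torch.randn(B, heads, n_class, head_dim)

# Move "heads" next to "head_dim" before flattening them into embed_dim.
x = x.permute(0, 2, 1, 3).reshape(B, n_class, embed_dim)
print(x.shape)  # torch.Size([2, 12, 64])
```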

Thank you for your timely reply and please correct me if I am wrong.

prismformore avatar Feb 09 '21 11:02 prismformore

No, we do not need to combine 'n_head' and 'embed_dim/n_head' here. We use 'n_head' here: https://github.com/xieenze/Trans2Seg/blob/a1591849ced066d32171a25cec01a25fcacccf2b/segmentron/models/trans2seg.py#L105

xieenze avatar Feb 09 '21 12:02 xieenze

@xieenze

Please let me make it clear. In this particular line: https://github.com/xieenze/Trans2Seg/blob/a1591849ced066d32171a25cec01a25fcacccf2b/segmentron/modules/transformer.py#L92 "(attn3 @ v)" has shape [B, nheads, n_class, C/nheads], and the reshape turns it into [B, n_class, C] by combining "nheads" and "C/nheads", which are not adjacent dimensions.

In this case, we would need to permute [B, nheads, n_class, C/nheads] to [B, n_class, nheads, C/nheads] beforehand and then reshape it into [B, n_class, C], as sketched below.
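A small PyTorch sketch of the difference (shapes are made up; `x` merely stands in for `(attn3 @ v)`):

```python
import torch

B, nheads, n_class, C = 2, 4, 12, 64  # illustrative; C divisible by nheads
x = torch.randn(B, nheads, n_class, C // nheads)  # stands in for (attn3 @ v)

direct = x.reshape(B, n_class, C)                 # current code path
fixed = x.transpose(1, 2).reshape(B, n_class, C)  # permute heads next to channels first

# The direct reshape interleaves values across heads and classes,
# so the per-class vectors generally differ between the two orderings.
print(torch.equal(direct, fixed))  # False in general

# With the fix, each class vector is the concatenation over heads:
assert torch.equal(fixed[0, 0], x[0, :, 0, :].reshape(-1))
```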

prismformore avatar Feb 09 '21 12:02 prismformore

OK, I know what you mean. It may be a small typo, and I don't know how it would influence the final result. I will try to fix it when I have time. Thanks for the advice.

xieenze avatar Feb 09 '21 13:02 xieenze

@xieenze The code works fine for me too. This is an interesting observation. Thank you again for your help!

prismformore avatar Feb 09 '21 13:02 prismformore

@prismformore Thanks for your interest in our work! 👍

xieenze avatar Feb 09 '21 15:02 xieenze