stn.pytorch
BCHW format
Excellent work!
I would like to use this in the middle of my PyTorch network, so my tensors are in [Batch x Channel x Height x Width] format. I tried to use permute to change the dimension order, but it was not successful. For example, when a = torch.randn((2,3,4,5)), a.stride() is (60, 20, 5, 1), but if I do b = a.permute((3,0,1,2)), b.stride() is (1, 60, 20, 5), while torch.randn(5,2,3,4).stride() is (24, 12, 4, 1).
Is there an easy and efficient way to do this, or do I need to change the .c and .cu files in src?
I guess a.permute((3,0,1,2)).contiguous() might be a solution, but I'm not sure whether it is safe for Variable (autograd).
Thank you.
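A minimal sketch of the stride behaviour in question, assuming the goal is to go from BCHW to BHWC: permute only creates a view with rearranged strides, while .contiguous() actually copies the data into the new layout.

import torch

a = torch.randn(2, 3, 4, 5)        # BCHW
print(a.stride())                   # (60, 20, 5, 1)

b = a.permute(0, 2, 3, 1)           # BCHW -> BHWC view; no data is moved
print(b.size())                     # torch.Size([2, 4, 5, 3])
print(b.stride())                   # (60, 5, 1, 20): the last dimension no longer has stride 1
print(b.is_contiguous())            # False

c = b.contiguous()                  # copies the data into the permuted layout
print(c.stride())                   # (60, 15, 3, 1): same as torch.randn(2, 4, 5, 3).stride()

Both permute and contiguous participate in autograd, so the permute(...).contiguous() pattern should be safe to use on a Variable.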
You can use transpose: img = img.transpose(1,2).transpose(2,3). This should change the BCHW layout to BHWC.
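As a quick sanity check (a minimal sketch; the two chained transposes are equivalent to a single permute(0, 2, 3, 1)):

import torch

img = torch.randn(2, 3, 8, 8)                    # BCHW
out = img.transpose(1, 2).transpose(2, 3)        # BCHW -> BHWC
print(out.size())                                # torch.Size([2, 8, 8, 3])
print(out.equal(img.permute(0, 2, 3, 1)))        # True: element-wise identical to the permute version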
But transpose(1,2).transpose(2,3) does not seem to rearrange the internal array. torch.FloatTensor(1,2,3,4).stride() and torch.FloatTensor(1,2,3,4).transpose(1,2).transpose(2,3).stride() are (24, 12, 4, 1) and (24, 4, 1, 12), respectively, while torch.FloatTensor(1,3,4,2).stride() is (24, 8, 2, 1).
So if I run the code, at lines 44 and 45 in my_lib.c,
real yf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth];
real xf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth + 1];
xf is not valid, because grids_strideWidth is still 1, so the +1 steps to the next width position in memory rather than to the x channel.
I guess it needs to be like
real xf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth + 1*grids_strideChannel];
although I have not tested it.
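A small Python sketch of the same indexing problem, re-creating the C code's offset arithmetic on the flat buffer (the grid shape and indices here are made up purely for illustration):

import torch

B, H, W = 1, 4, 4
grid_bchw = torch.randn(B, 2, H, W)                  # grid built in BCHW order
grid = grid_bchw.transpose(1, 2).transpose(2, 3)     # (B, H, W, 2), memory unchanged
flat = grid_bchw.view(-1)                            # the underlying buffer in memory order

sB, sH, sW, sC = grid.stride()                       # (32, 4, 1, 16): channel stride is not 1
b, y, x = 0, 2, 3
base = b * sB + y * sH + x * sW

print(flat[base].item()      == grid[b, y, x, 0].item())  # True: yf is read correctly
print(flat[base + 1].item()  == grid[b, y, x, 1].item())  # False: "+1" steps along W, not along the channel
print(flat[base + sC].item() == grid[b, y, x, 1].item())  # True: adding the channel stride fixes it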
transpose(1,2).transpose(2,3) changes the internal array; you can use .size() to check. I have been using that all the time; test_conv_stn.ipynb actually uses that, FYI.
On a separate note, I guess BCHW should be the standard because it follows the PyTorch conv layer convention. I will probably have a version for that later. Let me know what you think.
Oh sorry, I misunderstood; you are talking about the permutation for the grid rather than the image. Hmm, I always use the grid generator to generate the grid in BHWC format directly, so I never ran into the problem you mentioned.
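For what it's worth, in more recent PyTorch the built-in torch.nn.functional.affine_grid (not this repo's grid generator, shown here only as an illustration) also produces the sampling grid directly in B x H x W x 2 layout:

import torch
import torch.nn.functional as F

theta = torch.tensor([[[1.0, 0.0, 0.0],
                       [0.0, 1.0, 0.0]]])            # identity affine transform, shape (N, 2, 3)
grid = F.affine_grid(theta, size=(1, 3, 8, 8), align_corners=False)
print(grid.shape)                                    # torch.Size([1, 8, 8, 2])
print(grid.is_contiguous())                          # True, so the channel stride is 1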
Thank you. Meanwhile I'll use permute (or transpose) and then contiguous(). It seems to work properly so far :)
Thanks, Fei, for the nice work. Do you have any update on BCHW support?
@junyanz Hi Junyan, thank you for your interest; it will likely be added after the NIPS deadline. We find that the majority of users need BCHW instead of BHWC, so we will prioritize it :D
@fxia22 Thanks for your prompt response. Good luck with your NIPS submission.
BCHW support has been added; an example can be found in test.py.
Thanks a lot!
@fxia22 To go from BHWC to BCHW, just use img.permute(0, 3, 1, 2).
@edgarriba Thanks for your suggestion. As discussed above, the problem with BCHW for STN is that the BCHW layout is not suitable for memory coalescing. The permutation itself doesn't change the memory layout, but .contiguous() after permute will work.
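For example (a minimal sketch of the permute-then-contiguous pattern for BHWC to BCHW):

import torch

img_bhwc = torch.randn(2, 8, 8, 3)                     # B x H x W x C
img_bchw = img_bhwc.permute(0, 3, 1, 2).contiguous()   # view, then copy into B x C x H x W layout
print(img_bchw.size(), img_bchw.is_contiguous())       # torch.Size([2, 3, 8, 8]) True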
Ah, right. Permute just recomputes the strides.