
BCHW format

Open thnkim opened this issue 7 years ago • 13 comments

Excellent work!

I would like to use this in the middle of my PyTorch network, so my tensors are in [Batch x Channel x Height x Width] format. I tried to use permute to change their dimension order, but it was not successful. For example, when a = torch.randn((2,3,4,5)), a.stride() is (60, 20, 5, 1); but if I do b = a.permute((3,0,1,2)), b.stride() is (1, 60, 20, 5), while torch.randn(5,2,3,4).stride() is (24, 12, 4, 1).

Is there an easy and efficient way to do this, or do I need to change the .c and .cu files in src?

I guess a.permute((3,0,1,2)).contiguous() might be a solution, but I'm not sure it is safe for Variables (autograd).
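For reference, a minimal stride check along the lines of the example above (contiguous() is differentiable, so gradients do pass through it):

import torch

a = torch.randn(2, 3, 4, 5)
print(a.stride())                        # (60, 20, 5, 1): contiguous layout
b = a.permute(3, 0, 1, 2)                # a view with shape (5, 2, 3, 4)
print(b.stride())                        # (1, 60, 20, 5): strides reshuffled, memory untouched
print(torch.randn(5, 2, 3, 4).stride())  # (24, 12, 4, 1): truly contiguous
c = b.contiguous()                       # actually copies into a contiguous buffer
print(c.stride())                        # (24, 12, 4, 1)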

Thank you.

thnkim avatar Apr 11 '17 08:04 thnkim

You can use transpose: img = img.transpose(1,2).transpose(2,3); this should change the BCHW layout to BHWC.
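For example (a small sketch; the shapes are arbitrary):

import torch

img = torch.randn(8, 3, 32, 32)            # B, C, H, W
img = img.transpose(1, 2).transpose(2, 3)  # B, H, W, C
print(img.size())                          # torch.Size([8, 32, 32, 3])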

fxia22 avatar Apr 12 '17 07:04 fxia22

But transpose(1,2).transpose(2,3) does not seem to rearrange the internal array: torch.FloatTensor(1,2,3,4).stride() and torch.FloatTensor(1,2,3,4).transpose(1,2).transpose(2,3).stride() are (24, 12, 4, 1) and (24, 4, 1, 12) respectively, while torch.FloatTensor(1,3,4,2).stride() is (24, 8, 2, 1).

So when I run the code, at lines 44 and 45 in my_lib.c,

real yf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth];
real xf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth + 1];

xf is not valid, because grids_strideWidth is still 1, so the trailing + 1 steps along the width dimension rather than to the next channel.

I guess it needs to be something like

real xf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth + 1*grids_strideChannel];

although I have not tested it.
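The mismatch is easy to see from Python with the shapes above (this only inspects strides; it is not a fix to the C code):

import torch

g = torch.FloatTensor(1, 3, 4, 2)  # grid allocated directly in BHWC
print(g.stride())                  # (24, 8, 2, 1): channel stride 1, so "+ 1" reaches the x channel
gt = torch.FloatTensor(1, 2, 3, 4).transpose(1, 2).transpose(2, 3)  # BHWC view of BCHW data
print(gt.stride())                 # (24, 4, 1, 12): channel stride 12, so "+ 1" steps along width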

thnkim avatar Apr 12 '17 07:04 thnkim

transpose(1,2).transpose(2,3) changes the internal array; you can use .size() to check. I have been using that all the time; test_conv_stn.ipynb actually uses it, FYI.

On a separate note, I guess BCHW should be the standard because it follows the PyTorch conv layer convention. I will probably add a version for that later. Let me know what you think.

fxia22 avatar Apr 12 '17 08:04 fxia22

Oh sorry, I misunderstood: you are talking about permuting the grid rather than the image. Hmm, I always use the grid generator to produce the grid in BHWC format directly, so I never ran into the problem you mention.

fxia22 avatar Apr 12 '17 08:04 fxia22

Thank you. Meanwhile I'll use permute (or transpose) and then contiguous(). It seems to work properly so far :)
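For anyone landing here, a sketch of that workaround; sampler_bhwc below is a hypothetical stand-in for a BHWC-only op, not a name from this repo:

import torch

def sampler_bhwc(x):  # hypothetical stand-in for a BHWC-only op
    return x

x = torch.randn(2, 3, 8, 8)                  # BCHW input
x_bhwc = x.permute(0, 2, 3, 1).contiguous()  # real copy; channel stride becomes 1
y_bhwc = sampler_bhwc(x_bhwc)                # operate in BHWC
y = y_bhwc.permute(0, 3, 1, 2).contiguous()  # convert the result back to BCHW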

thnkim avatar Apr 12 '17 08:04 thnkim

Thanks, Fei, for the nice work. Do you have any update on BCHW support?

junyanz avatar May 05 '17 05:05 junyanz

@junyanz Hi Junyan, thank you for your interest; it is likely to be added after the NIPS deadline. We find that the majority of users need BCHW rather than BHWC, so we will prioritize it :D

fxia22 avatar May 05 '17 05:05 fxia22

@fxia22 Thanks for your prompt response. Good luck with your NIPS submission.

junyanz avatar May 05 '17 05:05 junyanz

BCHW support has been added; an example can be found in test.py.

fxia22 avatar Jun 12 '17 23:06 fxia22

Thanks a lot!

junyanz avatar Jun 12 '17 23:06 junyanz

@fxia22 to go from BHWC to BCHW, just use img.permute(0, 3, 1, 2).

edgarriba avatar Jun 14 '17 17:06 edgarriba

@edgarriba Thanks for your suggestion. As discussed above, the problem with BCHW for the STN is that the BCHW layout is not suitable for coalescing. Permutation by itself doesn't change the memory layout, but .contiguous() after permute will work.
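A quick way to check this:

import torch

x = torch.randn(2, 8, 8, 3)            # BHWC
y = x.permute(0, 3, 1, 2)              # BCHW view: strides recomputed, no copy
print(y.is_contiguous())               # False
print(y.contiguous().is_contiguous())  # True: contiguous() actually relays out the memory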

fxia22 avatar Jun 14 '17 17:06 fxia22

Ah, right. permute just recomputes the strides.

edgarriba avatar Jun 14 '17 20:06 edgarriba