A bug on FactorizedReduce in operations.py?

Open Dav-Jay opened this issue 5 years ago • 3 comments

I wonder whether there is a bug in this line in the FactorizedReduce function:

out = torch.cat([self.conv_1(x), self.conv_2(x[:,:,1:,1:])], dim=1)

NOTE: in the line above, the input to conv_2 is a slice of x.

Therefore, in some cases (for example, when the height or width of x is odd) the outputs of conv_1 and conv_2 will have different sizes in the 3rd and 4th dimensions. That causes an error in torch.cat, since the concatenated tensors must have the same shape in every dimension other than dim=1.
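To make the mismatch concrete, here is a minimal sketch of the shape arithmetic. It assumes conv_1 and conv_2 are 1x1 convolutions with stride 2 and no padding, which the thread does not show; `conv_out_size` and `branch_sizes` are hypothetical helper names used only for illustration.

```python
def conv_out_size(n, kernel=1, stride=2, padding=0):
    """Output length of a convolution (dilation 1) along one spatial axis."""
    return (n + 2 * padding - kernel) // stride + 1


def branch_sizes(n):
    """Sizes produced by conv_1(x) and conv_2(x[:, :, 1:, 1:]) along an axis of length n.

    Assumption: both convolutions use kernel 1, stride 2, padding 0.
    """
    return conv_out_size(n), conv_out_size(n - 1)


for n in (8, 7):
    full, sliced = branch_sizes(n)
    status = "match" if full == sliced else "mismatch"
    print(f"input {n}: conv_1 -> {full}, conv_2 -> {sliced} ({status})")
```

Under these assumptions an even length gives equal sizes on both branches, while an odd length leaves the sliced branch one element short.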

I'm surprised that no one seems to have run into this problem before.

Dav-Jay avatar Apr 17 '19 11:04 Dav-Jay

@Dav-Jay, I came across this bug when I used the code on a custom dataset. My patch for the bug is as follows:

  def forward(self, x):
    x = self.relu(x)
    even_conv = self.conv_1(x) 
    odd_conv = self.conv_2(x[:,:,1:,1:]) 

    diff2 = even_conv.shape[2] - odd_conv.shape[2] 
    diff3 = even_conv.shape[3] - odd_conv.shape[3] 

    if diff2 > 0 :
      odd_conv = odd_conv.permute(2,1,0,3)
      odd_conv = torch.cat((odd_conv,torch.zeros(diff2, odd_conv.shape[1], odd_conv.shape[2],odd_conv.shape[3], dtype=odd_conv.dtype, device=odd_conv.device)), dim=0)
      odd_conv = odd_conv.permute(2,1,0,3) 
    if diff3 > 0:
      odd_conv = odd_conv.permute(3,1,2,0)
      odd_conv = torch.cat((odd_conv,torch.zeros(diff3, odd_conv.shape[1], odd_conv.shape[2],odd_conv.shape[3], dtype=odd_conv.dtype, device=odd_conv.device)), dim=0)
      odd_conv = odd_conv.permute(3,1,2,0)
    if diff2 < 0:
      # diff2 is negative in this branch, so pad with -diff2 rows of zeros
      even_conv = even_conv.permute(2,1,0,3)
      even_conv = torch.cat((even_conv,torch.zeros(-diff2, even_conv.shape[1], even_conv.shape[2],even_conv.shape[3], dtype=even_conv.dtype, device=even_conv.device)), dim=0)
      even_conv = even_conv.permute(2,1,0,3)
    if diff3 < 0:
      # diff3 is negative in this branch, so pad with -diff3 columns of zeros
      even_conv = even_conv.permute(3,1,2,0)
      even_conv = torch.cat((even_conv,torch.zeros(-diff3, even_conv.shape[1], even_conv.shape[2],even_conv.shape[3], dtype=even_conv.dtype, device=even_conv.device)), dim=0)
      even_conv = even_conv.permute(3,1,2,0)
   
    out = torch.cat([even_conv, odd_conv], dim=1)
    out = self.bn(out)
    return out
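The same zero-padding idea can probably be written more compactly with torch.nn.functional.pad. This is only a sketch, not part of the patch above: `pad_to_match` is a hypothetical helper, and the tensors are stand-ins for the two convolution outputs.

```python
import torch
import torch.nn.functional as F


def pad_to_match(a, b):
    """Zero-pad the bottom/right of two NCHW tensors so their H and W agree."""
    h = max(a.shape[2], b.shape[2])
    w = max(a.shape[3], b.shape[3])
    # F.pad lists pads starting from the last dimension: (left, right, top, bottom).
    a = F.pad(a, (0, w - a.shape[3], 0, h - a.shape[2]))
    b = F.pad(b, (0, w - b.shape[3], 0, h - b.shape[2]))
    return a, b


# Stand-ins for conv_1(x) and conv_2(x[:, :, 1:, 1:]) on an odd-sized input.
even_conv = torch.ones(2, 4, 4, 4)
odd_conv = torch.ones(2, 4, 3, 3)

even_conv, odd_conv = pad_to_match(even_conv, odd_conv)
out = torch.cat([even_conv, odd_conv], dim=1)
print(out.shape)
```

Padding keeps every input pixel that conv_1 sees, at the cost of a zero row and column in the other half of the channels.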

wangfrombupt avatar Aug 07 '19 05:08 wangfrombupt

@wangfrombupt Thanks for mentioning this bug, which I had not noticed before. I guess the original code assumes the feature map (x) always has even width and height. I would fix it by:

out = torch.cat([self.conv_1(x[:,:,:-1,:-1]), self.conv_2(x[:,:,1:,1:])], dim=1)
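For reference, a self-contained sketch of a module using this cropping fix might look like the following. The constructor is an assumption (two 1x1, stride-2 convolutions followed by batch norm, using the attribute names that appear in this thread); it is not copied from operations.py.

```python
import torch
import torch.nn as nn


class FactorizedReduce(nn.Module):
    """Sketch only: the layer layout is assumed, not taken from operations.py."""

    def __init__(self, C_in, C_out, affine=True):
        super().__init__()
        assert C_out % 2 == 0
        self.relu = nn.ReLU(inplace=False)
        self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)

    def forward(self, x):
        x = self.relu(x)
        # Drop the last row/column for conv_1 and the first row/column for conv_2,
        # so both branches see inputs of the same spatial size.
        out = torch.cat(
            [self.conv_1(x[:, :, :-1, :-1]), self.conv_2(x[:, :, 1:, 1:])], dim=1
        )
        return self.bn(out)


layer = FactorizedReduce(16, 32)
shapes = [tuple(layer(torch.randn(2, 16, s, s)).shape) for s in (8, 7)]
print(shapes)
```

With a 1x1 kernel and stride 2, the crop does not change which pixels conv_1 samples on even-sized inputs; on odd-sized inputs it discards the last row and column rather than padding.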

bolianchen avatar Dec 19 '19 08:12 bolianchen

When the width and height are both even, it works fine. When the width or height is odd (which can be avoided by controlling the input size), there is a mismatch.

D-X-Y avatar Jan 17 '20 11:01 D-X-Y