darts
A bug on FactorizedReduce in operations.py?
I wonder whether there is a bug in this line of the FactorizedReduce module in operations.py:
out = torch.cat([self.conv_1(x), self.conv_2(x[:,:,1:,1:])], dim=1)
Note that the input to conv_2 is a slice of x. In some cases the outputs of conv_1 and conv_2 will therefore have different sizes in the 3rd and 4th dimensions, which makes torch.cat fail, since concatenated tensors must have exactly the same shape in every dimension except the one being concatenated. I'm surprised that no one seems to have run into this problem before.
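The mismatch can be checked arithmetically. FactorizedReduce applies 1x1 convolutions with stride 2 and no padding, so a spatial dimension of size n shrinks to floor((n - 1) / 2) + 1. A quick sketch (pure Python, no torch needed; the kernel/stride/padding values are assumed from the repo):

```python
def conv_out(n, kernel=1, stride=2, padding=0):
    # standard Conv2d output-size formula for one spatial dimension
    return (n + 2 * padding - kernel) // stride + 1

for h in range(2, 9):
    full = conv_out(h)        # conv_1 sees the full feature map
    sliced = conv_out(h - 1)  # conv_2 sees x[:, :, 1:, 1:]
    print(h, full, sliced, "OK" if full == sliced else "MISMATCH")
```

For even h the two sizes agree, but for every odd h the sliced branch comes out one element shorter, which is exactly the torch.cat failure described above.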
@Dav-Jay , I came across this bug when I used the code on a customized dataset. My patch for the bug is as follows:
def forward(self, x):
    x = self.relu(x)
    even_conv = self.conv_1(x)
    odd_conv = self.conv_2(x[:, :, 1:, 1:])
    # size differences along the spatial dimensions (H = dim 2, W = dim 3)
    diff2 = even_conv.shape[2] - odd_conv.shape[2]
    diff3 = even_conv.shape[3] - odd_conv.shape[3]
    if diff2 > 0:
        # odd_conv is shorter in H: move H to dim 0, append zero rows, move back
        odd_conv = odd_conv.permute(2, 1, 0, 3)
        odd_conv = torch.cat((odd_conv, torch.zeros(diff2, odd_conv.shape[1], odd_conv.shape[2], odd_conv.shape[3], dtype=odd_conv.dtype, device=odd_conv.device)), dim=0)
        odd_conv = odd_conv.permute(2, 1, 0, 3)
    if diff3 > 0:
        # odd_conv is shorter in W: move W to dim 0, append zero columns, move back
        odd_conv = odd_conv.permute(3, 1, 2, 0)
        odd_conv = torch.cat((odd_conv, torch.zeros(diff3, odd_conv.shape[1], odd_conv.shape[2], odd_conv.shape[3], dtype=odd_conv.dtype, device=odd_conv.device)), dim=0)
        odd_conv = odd_conv.permute(3, 1, 2, 0)
    if diff2 < 0:
        # even_conv is shorter in H; note the sign flip, since diff2 is negative here
        even_conv = even_conv.permute(2, 1, 0, 3)
        even_conv = torch.cat((even_conv, torch.zeros(-diff2, even_conv.shape[1], even_conv.shape[2], even_conv.shape[3], dtype=even_conv.dtype, device=even_conv.device)), dim=0)
        even_conv = even_conv.permute(2, 1, 0, 3)
    if diff3 < 0:
        # even_conv is shorter in W; same sign flip for the pad length
        even_conv = even_conv.permute(3, 1, 2, 0)
        even_conv = torch.cat((even_conv, torch.zeros(-diff3, even_conv.shape[1], even_conv.shape[2], even_conv.shape[3], dtype=even_conv.dtype, device=even_conv.device)), dim=0)
        even_conv = even_conv.permute(3, 1, 2, 0)
    out = torch.cat([even_conv, odd_conv], dim=1)
    out = self.bn(out)
    return out
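The permute/cat/permute sequence above is just zero-padding the smaller tensor on its bottom and right edges, which PyTorch can also express with a single torch.nn.functional.pad call. A sketch of the pad computation (pure Python so it runs without torch; pad_spec is a hypothetical helper name, and the tuples follow F.pad's (left, right, top, bottom) convention for NCHW tensors):

```python
def pad_spec(a_shape, b_shape):
    """Given two NCHW shapes, compute F.pad-style (left, right, top, bottom)
    specs that grow each tensor to the elementwise-max spatial size,
    padding only the bottom and right edges as the patch above does."""
    dh = a_shape[2] - b_shape[2]
    dw = a_shape[3] - b_shape[3]
    pad_a = (0, max(-dw, 0), 0, max(-dh, 0))  # applied to the first tensor
    pad_b = (0, max(dw, 0), 0, max(dh, 0))    # applied to the second tensor
    return pad_a, pad_b
```

With torch available, the four if-branches then collapse to something like even_conv = F.pad(even_conv, pad_a) and odd_conv = F.pad(odd_conv, pad_b).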
@wangfrombupt Thanks for mentioning this bug, which I had not noticed before. I guess the original code assumes the feature map x always has even width and height. I would fix it with: out = torch.cat([self.conv_1(x[:,:,:-1,:-1]), self.conv_2(x[:,:,1:,1:])], dim=1)
When the width and height are even, it works fine. When the width or height is odd (which can be avoided by controlling the input), there is a mismatch.
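The proposed fix can be sanity-checked with the Conv2d output-size formula: both convolutions now see an input of spatial size h-1, so their outputs agree for every h, and for even h the result keeps the same size as the unpatched code. A sketch, assuming the repo's conv parameters (1x1 kernel, stride 2, no padding):

```python
def conv_out(n, kernel=1, stride=2, padding=0):
    # standard Conv2d output-size formula for one spatial dimension
    return (n + 2 * padding - kernel) // stride + 1

for h in range(2, 10):
    left = conv_out(h - 1)   # conv_1 on x[:, :, :-1, :-1]
    right = conv_out(h - 1)  # conv_2 on x[:, :, 1:, 1:]
    assert left == right     # shapes now match for any h, odd or even
    if h % 2 == 0:
        # for even h, dropping the last row/column does not change the
        # output size, so the fix is backward-compatible with the old code
        assert left == conv_out(h)
```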