Involution icon indicating copy to clipboard operation
Involution copied to clipboard

Question: there may be something wrong?

Open crs904620522 opened this issue 2 years ago • 1 comments

thanks for your contribution! Here, for some reason, i need to realize the "involution2D,3D" by myself, and I take this project for validation. However, my results can not be the same as yours. In the begining, i think it may be my fault, but after check i am not sure!!! So could you help me? Here is my question: 1、I think the “Tensor.unfold()" use in "involution.py" are not right........( may be ). Here is the code ( with problems): ‘’‘ input_unfolded = self.pad(input_initial)
.unfold(dimension=2, size=self.kernel_size[0], step=self.stride[0])
.unfold(dimension=3, size=self.kernel_size[1], step=self.stride[1])
.unfold(dimension=4, size=self.kernel_size[2], step=self.stride[2]) input_unfolded = input_unfolded.reshape(batch_size, self.groups, self.out_channels // self.groups, self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2], -1) input_unfolded = input_unfolded.reshape(tuple(input_unfolded.shape[:-1]) + (out_depth, out_height, out_width)) ’‘’

In officials, they use "nn.Unfold()" and this is right. the Tensor.unfold() returns ”B,C,H,W,K,K“, and the "nn.Unfold()" returns "B,CxKxK,HxW". So I think the " permute" needed be used if use ”Tensor.unfold()“. And I give an example for comparsion: ################The Code:##############

def nnUnfold_Tensorunfold(): input = torch.ones((1, 1, 5, 5)) # ----------------nnUnfold----------------- # Unfold1 = nn.Unfold(3, 1, (3 - 1) // 2, 1) input_unfolded = Unfold1(input) #====>B,CxKxK,HxW input_unfolded = input_unfolded.contiguous().view(1,9,5,5) print("Official: nn.Unfold():",input_unfolded) # ---------------Tensorunfold--------------- # pad = nn.ConstantPad2d(padding=(1, 1,1, 1), value=0.) input = pad(input) input_unfolded = input input_unfolded = input_unfolded.unfold(dimension=2, size=3, step=1) input_unfolded = input_unfolded.unfold(dimension=3, size=3, step=1) #===>B,C,H,W,K,K before = input_unfolded.contiguous().view(1,9,5,5) print("Wrong: Tensor.unfold():",before) after = input_unfolded.permute(0,1,4,5,2,3).contiguous().view(1,9,5,5) #====> permute should be used print("Right: after permute:",after) # --------------------------------- # if name == 'main': nnUnfold_Tensorunfold()

################The Results:############## Official: nn.Unfold(): tensor([[[[0., 0., 0., 0., 0.], [0., 1., 1., 1., 1.], [0., 1., 1., 1., 1.], [0., 1., 1., 1., 1.], [0., 1., 1., 1., 1.]],

     [[0., 0., 0., 0., 0.],
      [1., 1., 1., 1., 1.],
      [1., 1., 1., 1., 1.],
      [1., 1., 1., 1., 1.],
      [1., 1., 1., 1., 1.]],

     [[0., 0., 0., 0., 0.],
      [1., 1., 1., 1., 0.],
      [1., 1., 1., 1., 0.],
      [1., 1., 1., 1., 0.],
      [1., 1., 1., 1., 0.]],

     [[0., 1., 1., 1., 1.],
      [0., 1., 1., 1., 1.],
      [0., 1., 1., 1., 1.],
      [0., 1., 1., 1., 1.],
      [0., 1., 1., 1., 1.]],

     [[1., 1., 1., 1., 1.],
      [1., 1., 1., 1., 1.],
      [1., 1., 1., 1., 1.],
      [1., 1., 1., 1., 1.],
      [1., 1., 1., 1., 1.]],

     [[1., 1., 1., 1., 0.],
      [1., 1., 1., 1., 0.],
      [1., 1., 1., 1., 0.],
      [1., 1., 1., 1., 0.],
      [1., 1., 1., 1., 0.]],

     [[0., 1., 1., 1., 1.],
      [0., 1., 1., 1., 1.],
      [0., 1., 1., 1., 1.],
      [0., 1., 1., 1., 1.],
      [0., 0., 0., 0., 0.]],

     [[1., 1., 1., 1., 1.],
      [1., 1., 1., 1., 1.],
      [1., 1., 1., 1., 1.],
      [1., 1., 1., 1., 1.],
      [0., 0., 0., 0., 0.]],

     [[1., 1., 1., 1., 0.],
      [1., 1., 1., 1., 0.],
      [1., 1., 1., 1., 0.],
      [1., 1., 1., 1., 0.],
      [0., 0., 0., 0., 0.]]]])

Wrong: Tensor.unfold(): tensor([[[[0., 0., 0., 0., 1.], [1., 0., 1., 1., 0.], [0., 0., 1., 1., 1.], [1., 1., 1., 0., 0.], [0., 1., 1., 1., 1.]],

     [[1., 1., 0., 0., 0.],
      [1., 1., 1., 1., 1.],
      [1., 0., 0., 0., 1.],
      [1., 0., 1., 1., 0.],
      [0., 1., 1., 0., 1.]],

     [[1., 0., 1., 1., 1.],
      [1., 1., 1., 1., 1.],
      [1., 1., 1., 1., 1.],
      [1., 1., 1., 1., 1.],
      [1., 1., 1., 1., 1.]],

     [[1., 1., 1., 1., 1.],
      [1., 1., 1., 0., 1.],
      [1., 0., 1., 1., 0.],
      [0., 1., 1., 0., 1.],
      [1., 0., 1., 1., 1.]],

     [[1., 1., 1., 1., 1.],
      [1., 1., 1., 1., 1.],
      [1., 1., 1., 1., 1.],
      [1., 1., 1., 1., 1.],
      [1., 1., 1., 1., 1.]],

     [[1., 1., 1., 0., 1.],
      [1., 0., 1., 1., 0.],
      [0., 1., 1., 0., 1.],
      [1., 0., 1., 1., 1.],
      [1., 1., 1., 1., 1.]],

     [[1., 1., 1., 1., 1.],
      [1., 1., 1., 1., 1.],
      [1., 1., 1., 1., 1.],
      [1., 1., 1., 1., 1.],
      [1., 1., 1., 0., 1.]],

     [[1., 0., 1., 1., 0.],
      [0., 1., 1., 0., 1.],
      [1., 0., 0., 0., 1.],
      [1., 1., 1., 1., 1.],
      [0., 0., 0., 1., 1.]],

     [[1., 1., 1., 1., 0.],
      [0., 0., 1., 1., 1.],
      [1., 1., 1., 0., 0.],
      [0., 1., 1., 0., 1.],
      [1., 0., 0., 0., 0.]]]])

Right: after permute: tensor([[[[0., 0., 0., 0., 0.], [0., 1., 1., 1., 1.], [0., 1., 1., 1., 1.], [0., 1., 1., 1., 1.], [0., 1., 1., 1., 1.]],

     [[0., 0., 0., 0., 0.],
      [1., 1., 1., 1., 1.],
      [1., 1., 1., 1., 1.],
      [1., 1., 1., 1., 1.],
      [1., 1., 1., 1., 1.]],

     [[0., 0., 0., 0., 0.],
      [1., 1., 1., 1., 0.],
      [1., 1., 1., 1., 0.],
      [1., 1., 1., 1., 0.],
      [1., 1., 1., 1., 0.]],

     [[0., 1., 1., 1., 1.],
      [0., 1., 1., 1., 1.],
      [0., 1., 1., 1., 1.],
      [0., 1., 1., 1., 1.],
      [0., 1., 1., 1., 1.]],

     [[1., 1., 1., 1., 1.],
      [1., 1., 1., 1., 1.],
      [1., 1., 1., 1., 1.],
      [1., 1., 1., 1., 1.],
      [1., 1., 1., 1., 1.]],

     [[1., 1., 1., 1., 0.],
      [1., 1., 1., 1., 0.],
      [1., 1., 1., 1., 0.],
      [1., 1., 1., 1., 0.],
      [1., 1., 1., 1., 0.]],

     [[0., 1., 1., 1., 1.],
      [0., 1., 1., 1., 1.],
      [0., 1., 1., 1., 1.],
      [0., 1., 1., 1., 1.],
      [0., 0., 0., 0., 0.]],

     [[1., 1., 1., 1., 1.],
      [1., 1., 1., 1., 1.],
      [1., 1., 1., 1., 1.],
      [1., 1., 1., 1., 1.],
      [0., 0., 0., 0., 0.]],

     [[1., 1., 1., 1., 0.],
      [1., 1., 1., 1., 0.],
      [1., 1., 1., 1., 0.],
      [1., 1., 1., 1., 0.],
      [0., 0., 0., 0., 0.]]]])

######################################## Maybe i am wrong..... could you help me?

crs904620522 avatar May 15 '22 15:05 crs904620522

Here I take the 2D unfold as an example, the true issue is in 3D unfold

crs904620522 avatar May 15 '22 15:05 crs904620522