Pytorch-CapsuleNet
squash function does not seem to reduce each 16-element vector to a single value (its length)
In capsnet.py, function squash:
def squash(self, input_tensor):
    print(f'input_tensor.shape: {input_tensor.shape}')
    squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
    print(f'squared_norm.shape: {squared_norm.shape}')
Printed output:
input_tensor.shape: torch.Size([100, 1, 10, 16, 1])
squared_norm.shape: torch.Size([100, 1, 10, 16, 1])
In my understanding, squared_norm should be the squared length of each capsule vector, and that vector lives in the 16-element dimension (dim=3, i.e. dim=-2). So after ** 2 and .sum, each vector should collapse to a single number, and I expected the output shape to be [100, 1, 10, 1, 1]. However, the code calls .sum on the last dimension (dim=-1), which has size 1, so the sum does nothing and the shape stays [100, 1, 10, 16, 1]. This looks incorrect to me.
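The shape behaviour described above can be checked with a minimal PyTorch sketch. It assumes the [100, 1, 10, 16, 1] layout from the printed output; the random tensor `u`, the standalone `squash` function, and its `dim` and `eps` parameters are illustrative assumptions, not the repository's code. The squash itself follows the standard CapsNet formula v = |s|^2 / (1 + |s|^2) * s / |s|, computed along the chosen axis.

```python
import torch

# Hypothetical stand-in for the capsule outputs printed above:
# 100 samples x 1 x 10 capsules x 16-element vectors x trailing singleton axis.
u = torch.randn(100, 1, 10, 16, 1)

# Summing over the last axis (size 1) is a no-op: each "sum" has a single term,
# so the result is just u ** 2 with the same shape.
sq_last = (u ** 2).sum(-1, keepdim=True)
print(sq_last.shape)  # torch.Size([100, 1, 10, 16, 1])

# Summing over the 16-element axis (dim=3, equivalently dim=-2 for a 5-D tensor)
# gives one squared length per capsule.
sq_caps = (u ** 2).sum(dim=-2, keepdim=True)
print(sq_caps.shape)  # torch.Size([100, 1, 10, 1, 1])


def squash(s: torch.Tensor, dim: int = -2, eps: float = 1e-8) -> torch.Tensor:
    """Squash capsule vectors along `dim` (assumed to be the 16-element axis).

    v = |s|^2 / (1 + |s|^2) * s / |s|; `eps` avoids division by zero.
    """
    squared_norm = (s ** 2).sum(dim=dim, keepdim=True)  # one value per capsule
    scale = squared_norm / (1.0 + squared_norm)
    return scale * s / torch.sqrt(squared_norm + eps)


v = squash(u)
# Length of each squashed capsule vector, computed over the same axis.
lengths = (v ** 2).sum(dim=-2).sqrt()  # shape [100, 1, 10, 1]
print(v.shape, bool((lengths < 1).all()))  # torch.Size([100, 1, 10, 16, 1]) True
```

With keepdim=True the squared norm keeps a size-1 axis in place of the 16 components, so it broadcasts back against `s` in the final multiplication and every squashed vector ends up with length strictly below 1.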
Am I misunderstanding something?