CapsNet-pytorch
CapsNet-pytorch copied to clipboard
squash
I have a question: according to the paper, is the squash function only applied after the weighted sum of the predictions $\hat{u}$, i.e. to $s_j = \sum_i c_{ij}\,\hat{u}_{j|i}$? In this code, however, there is also a squash right after the primary capsule layer, which really confuses me.
class PrimaryCapsLayer(nn.Module):
    """Convolution whose output channels are regrouped into capsule vectors.

    The convolution emits ``output_caps * output_dim`` channels. They are split
    into ``output_caps`` consecutive groups of ``output_dim`` channels, and every
    spatial position of every group becomes one ``output_dim``-dimensional
    capsule vector. The flattened capsules are passed through ``squash`` before
    being returned.

    Args:
        input_channels: channel count of the incoming feature map.
        output_caps: number of capsule channel groups.
        output_dim: length of each capsule vector.
        kernel_size: kernel size of the capsule convolution.
        stride: stride of the capsule convolution.

    Example configuration noted by the original author: input_channels=256,
    output_caps=32, output_dim=8, kernel_size=9, stride=2.
    """

    def __init__(self, input_channels, output_caps, output_dim, kernel_size, stride):
        super().__init__()
        # One conv produces every capsule component for every capsule group.
        self.conv = nn.Conv2d(
            input_channels,
            output_caps * output_dim,
            kernel_size=kernel_size,
            stride=stride,
        )
        self.input_channels = input_channels
        self.output_caps = output_caps
        self.output_dim = output_dim

    def forward(self, input):
        """Compute squashed primary capsules.

        Args:
            input: feature map fed to ``self.conv``; its output is unpacked as
                4-D (N, C, H, W), so a batched 4-D input is expected.

        Returns:
            ``squash`` applied to a tensor of shape
            (N, output_caps * H * W, output_dim), where H and W are the
            spatial size of the convolution output.
        """
        features = self.conv(input)
        batch, _, height, width = features.size()

        # Split the channel axis into (capsule group, capsule component):
        # (N, caps * dim, H, W) -> (N, caps, dim, H, W).
        capsules = features.view(batch, self.output_caps, self.output_dim, height, width)

        # Move the component axis last: (N, caps, H, W, dim). The permute gives a
        # non-contiguous tensor, so .contiguous() is required before the next view.
        capsules = capsules.permute(0, 1, 3, 4, 2).contiguous()

        # Treat every (group, row, column) triple as one capsule:
        # (N, caps * H * W, dim).
        capsules = capsules.view(batch, -1, self.output_dim)

        # Squashing here is intentional: in Sabour et al. (2017) the
        # PrimaryCapsules layer also uses the squash (Eq. 1) as its non-linearity,
        # not only the routed sums in the next layer.
        # NOTE(review): assumes squash normalises over the last (capsule-vector)
        # axis; confirm against its definition.
        return squash(capsules)