Matrix-Capsules-EM-PyTorch icon indicating copy to clipboard operation
Matrix-Capsules-EM-PyTorch copied to clipboard

class-capsule 层是否有一个错误?关于权重共享

Open fangxu622 opened this issue 4 years ago • 1 comments

这个 地方如果是class_caps 类型,这个地方的地方是否应该修改 为

修改代码

if w_shared:
     hw = int(B / w.size(1))
     w = w.repeat(1, hw, 1, 1, 1)
else:
    w = w.repeat(b, 1, 1, 1, 1)

原代码

    def transform_view(self, x, w, C, P, w_shared=False):
        """
            For conv_caps:
                Input:     (b*H*W, K*K*B, P*P)
                Output:    (b*H*W, K*K*B, C, P*P)
            For class_caps:
                Input:     (b, H*W*B, P*P)
                Output:    (b, H*W*B, C, P*P)
        """
        b, B, psize = x.shape
        assert psize == P*P

        x = x.view(b, B, 1, P, P)
        if w_shared:
            hw = int(B / w.size(1))
            w = w.repeat(1, hw, 1, 1, 1)

        w = w.repeat(b, 1, 1, 1, 1)
        x = x.repeat(1, 1, C, 1, 1)
        v = torch.matmul(x, w)
        v = v.view(b, B, C, P*P)
        return v

fangxu622 avatar Mar 12 '20 21:03 fangxu622

理解了,没有错误,但是权重如果不放在 conv caps 类的初始化里面 ,可能更好理解一点。但是这样写更紧凑 image

fangxu622 avatar Mar 12 '20 21:03 fangxu622