Matrix-Capsules-EM-PyTorch
Matrix-Capsules-EM-PyTorch copied to clipboard
class-capsule 层是否有一个错误?关于权重共享
这个 地方如果是class_caps 类型,这个地方的地方是否应该修改 为
修改代码
if w_shared:
hw = int(B / w.size(1))
w = w.repeat(1, hw, 1, 1, 1)
else:
w = w.repeat(b, 1, 1, 1, 1)
原代码
def transform_view(self, x, w, C, P, w_shared=False):
"""
For conv_caps:
Input: (b*H*W, K*K*B, P*P)
Output: (b*H*W, K*K*B, C, P*P)
For class_caps:
Input: (b, H*W*B, P*P)
Output: (b, H*W*B, C, P*P)
"""
b, B, psize = x.shape
assert psize == P*P
x = x.view(b, B, 1, P, P)
if w_shared:
hw = int(B / w.size(1))
w = w.repeat(1, hw, 1, 1, 1)
w = w.repeat(b, 1, 1, 1, 1)
x = x.repeat(1, 1, C, 1, 1)
v = torch.matmul(x, w)
v = v.view(b, B, C, P*P)
return v
理解了,没有错误,但是权重如果不放在 conv caps 类的初始化里面 ,可能更好理解一点。但是这样写更紧凑