Capsule
I ran into a problem when running this with Keras 2.2.4:
b = K.batch_dot(o, u_hat_vecs, [2, 3])
ValueError: Can not do batch_dot on inputs with shapes (None, 10, 10, 16) and (None, 10, None, 16) with axes=[2, 3]. x.shape[2] != y.shape[3] (10 != 16)
Also, in b = K.batch_dot(o, u_hat_vecs, [2, 3]), shouldn't b be accumulated with the b from the previous iteration?
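Since tf 1.14, the expected behavior of batch_dot has changed; o was originally supposed to come out as (None, 10, 16). If you are using TensorFlow, changing the routing loop to the following is enough: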
for i in range(self.routings):
    c = softmax(b, 1)  # shape = [None, num_capsule, input_num_capsule]
    # o = K.batch_dot(c, u_hat_vecs, [2, 2])
    o = tf.einsum('bin,binj->bij', c, u_hat_vecs)
    if i < self.routings - 1:
        o = K.l2_normalize(o, -1)
        # b = K.batch_dot(o, u_hat_vecs, [2, 3])
        b = tf.einsum('bij,binj->bin', o, u_hat_vecs)
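For anyone who wants to check the shapes without building the model, here is a minimal sketch of the same loop in plain NumPy. It is not the layer's actual code: np.einsum stands in for tf.einsum (same subscript notation), the softmax, l2_normalize and route helpers are hypothetical, the zero initialisation of b is an assumption, and the sizes (batch 2, 10 capsules, 32 input capsules, dimension 16) are made up for illustration.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def l2_normalize(x, axis=-1, eps=1e-12):
    # Hypothetical stand-in for K.l2_normalize.
    norm_sq = (x ** 2).sum(axis=axis, keepdims=True)
    return x / np.sqrt(np.maximum(norm_sq, eps))

def route(u_hat_vecs, routings=3):
    # u_hat_vecs: (batch, num_capsule, input_num_capsule, dim_capsule)
    batch, num_capsule, input_num_capsule, _ = u_hat_vecs.shape
    # Assumption: routing logits start at zero.
    b = np.zeros((batch, num_capsule, input_num_capsule))
    for i in range(routings):
        c = softmax(b, 1)  # (batch, num_capsule, input_num_capsule)
        # Weighted sum over input capsules n: result is (batch, num_capsule, dim_capsule).
        o = np.einsum('bin,binj->bij', c, u_hat_vecs)
        if i < routings - 1:
            o = l2_normalize(o, -1)
            # Agreement between o and each prediction vector: (batch, num_capsule, input_num_capsule).
            b = np.einsum('bij,binj->bin', o, u_hat_vecs)
    return o, b

rng = np.random.default_rng(0)
u_hat_vecs = rng.normal(size=(2, 10, 32, 16))
o, b = route(u_hat_vecs)
print(o.shape, b.shape)
```

With these made-up sizes, o has one 16-dimensional vector per output capsule and b has one logit per (output capsule, input capsule) pair, which are the shapes the original batch_dot calls were meant to produce.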
Also, I tried accumulating b with the result from the previous iteration, and the results actually got worse.
b = tf.einsum('bij,binj->bin', o, u_hat_vecs)
That's the right answer. Thanks for helping clear this up for others.