
A problem encountered when running with Keras 2.2.4

Open ghost opened this issue 6 years ago • 2 comments

b = K.batch_dot(o, u_hat_vecs, [2, 3])
ValueError: Can not do batch_dot on inputs with shapes (None, 10, 10, 16) and (None, 10, None, 16) with axes=[2, 3]. x.shape[2] != y.shape[3] (10 != 16)

Also, in b = K.batch_dot(o, u_hat_vecs, [2, 3]), shouldn't b be added to the b from the previous iteration?

ghost avatar Oct 31 '19 10:10 ghost

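Since TF 1.14, the behavior of batch_dot has changed; the resulting o was originally supposed to have shape (None, 10, 16). If you are using TensorFlow, changing the routing loop as follows is enough: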

for i in range(self.routings):
    c = softmax(b, 1)  # shape = [None, num_capsule, input_num_capsule]
    # o = K.batch_dot(c, u_hat_vecs, [2, 2])
    o = tf.einsum('bin,binj->bij', c, u_hat_vecs)
    if i < self.routings - 1:
        o = K.l2_normalize(o, -1)
        # b = K.batch_dot(o, u_hat_vecs, [2, 3])
        b = tf.einsum('bij,binj->bin', o, u_hat_vecs)
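
The shapes produced by the two einsum calls can be checked outside Keras. The sketch below uses NumPy's np.einsum as a stand-in for tf.einsum, since both use the same subscript notation. The sizes are assumptions for illustration: 10 and 16 match the shapes in the error message, while the batch size of 4 and the 32 input capsules are made up.

```python
import numpy as np

# Assumed sizes for illustration only: batch=4, num_capsule=10,
# input_num_capsule=32, dim_capsule=16.
B, I, N, J = 4, 10, 32, 16

rng = np.random.default_rng(0)
c = rng.random((B, I, N))              # coupling coefficients, indexed 'bin'
u_hat_vecs = rng.random((B, I, N, J))  # prediction vectors, indexed 'binj'

# Weighted sum over the input capsules n: one vector per output capsule.
o = np.einsum('bin,binj->bij', c, u_hat_vecs)

# Dot product over the capsule dimension j: one logit per (output, input) pair.
b = np.einsum('bij,binj->bin', o, u_hat_vecs)

print(o.shape)  # (4, 10, 16)
print(b.shape)  # (4, 10, 32)
```

Writing the contraction with explicit index letters makes the intended result independent of how a particular batch_dot version treats the extra dimensions.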

Also, I tried adding b to the previous iteration's result, and the results actually got worse.
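
For reference, the two update rules under discussion can be written side by side. This is a minimal NumPy sketch rather than the repository's code: the function name route, the accumulate flag, the helper functions, and the zero initialization of b are all assumptions made for illustration, and the sketch makes no claim about which variant trains better.

```python
import numpy as np

def softmax(x, axis):
    # Numerically stable softmax along the given axis.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def l2_normalize(x, axis=-1, eps=1e-12):
    # Scale x to unit L2 norm along the given axis.
    sq = (x ** 2).sum(axis=axis, keepdims=True)
    return x / np.sqrt(np.maximum(sq, eps))

def route(u_hat_vecs, routings=3, accumulate=False):
    # u_hat_vecs: (batch, num_capsule, input_num_capsule, dim_capsule)
    b = np.zeros(u_hat_vecs.shape[:3])  # assumed: logits start at zero
    for i in range(routings):
        c = softmax(b, 1)
        o = np.einsum('bin,binj->bij', c, u_hat_vecs)
        if i < routings - 1:
            o = l2_normalize(o, -1)
            agreement = np.einsum('bij,binj->bin', o, u_hat_vecs)
            # accumulate=True adds to the previous logits;
            # accumulate=False overwrites them, as in the loop above.
            b = b + agreement if accumulate else agreement
    return o

u = np.random.default_rng(0).random((2, 10, 5, 16))
print(route(u).shape)                   # (2, 10, 16)
print(route(u, accumulate=True).shape)  # (2, 10, 16)
```

Under the zero-initialization assumption, the two variants only start to differ at the third routing iteration, because the first update adds to a tensor of zeros.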

zhen8838 avatar Dec 28 '19 04:12 zhen8838

b = tf.einsum('bij,binj->bin', o, u_hat_vecs)

That is the correct fix. Thanks for helping others out.

bojone avatar Dec 28 '19 04:12 bojone