ELEGANT
ELEGANT copied to clipboard
一個小問題
你把原來的z做concat是為了更新encoder吧? 畢竟swap元素這個操作不可微分。 那你當初在搭這個模型的時候有試過straight through estimatior嗎?
甘溫
順便分享一下我有在嘗試用tensorflow跑這個模型的實驗
swap跟剛剛講的straight through estimator我是這樣做
def get_idx(z, y):
s = [tf.range(z.shape[i]) for i in range(3)]
d1, d2 ,d3 = tf.meshgrid(s[1], s[0], s[2])
idx=tf.stack([d2, d1, d3], axis=-1)
_, h, w, _ = idx.shape
y=tf.repeat(tf.repeat(y[:, None, None], h, axis=1), w, axis=2)[...,None]
idx = tf.concat([idx , y], axis=-1)
return idx
def get_corr_ele(z, y1, y2):
idx1=get_idx(z, y1)
idx2=get_idx(z, y2)
idx = tf.concat([idx1, idx2], axis=0)
ele=tf.gather_nd(z,idx)
return ele, idx
def swap(z1, z2, y1, y2):
z1y, idx=self.get_corr_ele(z1, y1, y2)
z2y, _=self.get_corr_ele(z2, y1, y2)
z12 = tf.tensor_scatter_nd_update(z1, idx, z2y)
#z21 = tf.tensor_scatter_nd_update(z2, idx, x1y)
#straight throguh estimator
z12 = z1 + tf.stop_gradient(z12-z1)
#z21 = z2 + tf.stop_gradient(z21-z2)
z12 = tf.concat([z12, z1], axis=-1)
return z12 #, z21