few-shot-ssl-public

What does m_dist_1 += tf.to_float(tf.equal(m_dist_1, 0.0)) mean?

Open jinghanSunn opened this issue 4 years ago • 1 comment

In clustering, I don't understand this code:

# Run clustering.
for tt in range(num_cluster_steps):
  protos_1 = tf.expand_dims(protos, 2)
  protos_2 = tf.expand_dims(h_unlabel, 1)
  pair_dist = tf.reduce_sum((protos_1 - protos_2)**2, [3])  # [B, K, N]
  m_dist = tf.reduce_mean(pair_dist, [2])  # [B, K]
  m_dist_1 = tf.expand_dims(m_dist, 1)  # [B, 1, K]
  m_dist_1 += tf.to_float(tf.equal(m_dist_1, 0.0))

Does m_dist_1 += tf.to_float(tf.equal(m_dist_1, 0.0)) mean that if the mean distance to a cluster center is 0, then 1 is added to it? If so, why add 1?

jinghanSunn, Sep 04 '20 03:09

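This is to prevent it from being zero. tf.equal(m_dist_1, 0.0) is true only where an entry is exactly 0.0, and tf.to_float turns that into 1.0 (0.0 everywhere else), so a zero entry becomes 1.0 and every nonzero entry is left unchanged.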

renmengye, Sep 04 '20 14:09
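
The masking trick discussed in this thread can be illustrated without TensorFlow. The following is a minimal plain-Python sketch, not code from the repository: the helper name add_zero_mask and the sample distances are hypothetical, and float(v == 0.0) stands in for tf.to_float(tf.equal(v, 0.0)).

```python
def add_zero_mask(values):
    """Return a copy of values where exact zeros become 1.0."""
    # float(v == 0.0) is 1.0 when v is exactly zero and 0.0 otherwise,
    # so adding it changes only the zero entries.
    return [v + float(v == 0.0) for v in values]


# Hypothetical mean distances for K = 4 clusters (made-up numbers).
m_dist = [0.0, 2.5, 0.0, 4.0]
safe = add_zero_mask(m_dist)
print(safe)
```

Only entries that are exactly 0.0 are affected; a very small but nonzero distance passes through unchanged. That distinguishes this pattern from adding a small epsilon to every entry, which would shift all values.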