few-shot-ssl-public
What does m_dist_1 += tf.to_float(tf.equal(m_dist_1, 0.0)) mean?
In the clustering step, I don't understand this code:
# Run clustering.
for tt in range(num_cluster_steps):
    protos_1 = tf.expand_dims(protos, 2)
    protos_2 = tf.expand_dims(h_unlabel, 1)
    pair_dist = tf.reduce_sum((protos_1 - protos_2)**2, [3])  # [B, K, N]
    m_dist = tf.reduce_mean(pair_dist, [2])  # [B, K]
    m_dist_1 = tf.expand_dims(m_dist, 1)  # [B, 1, K]
    m_dist_1 += tf.to_float(tf.equal(m_dist_1, 0.0))
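The shapes in the loop body can be traced with a small NumPy sketch. It assumes NumPy as a stand-in for the TF1 ops, and it assumes `protos` is `[B, K, D]` and `h_unlabel` is `[B, N, D]`, which is inferred from the `[B, K, N]` comment rather than stated in the thread. The sizes are arbitrary.

```python
import numpy as np

# Assumed shapes (inferred, not from the thread): protos [B, K, D], h_unlabel [B, N, D].
B, K, N, D = 2, 3, 5, 4
rng = np.random.default_rng(0)
protos = rng.normal(size=(B, K, D)).astype(np.float32)
h_unlabel = rng.normal(size=(B, N, D)).astype(np.float32)

protos_1 = protos[:, :, None, :]     # [B, K, 1, D], like tf.expand_dims(protos, 2)
protos_2 = h_unlabel[:, None, :, :]  # [B, 1, N, D], like tf.expand_dims(h_unlabel, 1)

# Broadcasting gives every (prototype, unlabeled point) pair: squared distance [B, K, N].
pair_dist = np.sum((protos_1 - protos_2) ** 2, axis=3)

# Mean squared distance from each prototype to all unlabeled points: [B, K].
m_dist = np.mean(pair_dist, axis=2)

# Extra axis so it can broadcast against other [B, ?, K] tensors: [B, 1, K].
m_dist_1 = m_dist[:, None, :]

print(pair_dist.shape, m_dist.shape, m_dist_1.shape)
```

Because `m_dist` is a mean of squared distances, it can be 0.0 only when every unlabeled point coincides exactly with that prototype.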
Does m_dist_1 += tf.to_float(tf.equal(m_dist_1, 0.0)) mean that if the distance from the cluster center is 1, then 1 is added?
But why add 1?
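This is to prevent it from being zero. tf.equal(m_dist_1, 0.0) is True only where an entry is exactly 0.0, and tf.to_float turns that mask into 1.0 there and 0.0 everywhere else. Adding the mask therefore changes 0.0 entries to 1.0 and leaves every nonzero entry unchanged.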
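A minimal NumPy sketch of the same trick, assuming NumPy as a stand-in for the TF1 ops (`tf.equal` maps to `np.equal`, `tf.to_float` to `.astype`). The helper name `replace_zeros_with_one` is hypothetical. The thread does not say why a zero must be avoided; a common reason is that the value is later used as a divisor, and that use is assumed here only for illustration.

```python
import numpy as np


def replace_zeros_with_one(x):
    """Hypothetical NumPy equivalent of: x += tf.to_float(tf.equal(x, 0.0)).

    np.equal(x, 0.0) is True exactly where x == 0.0. Casting the mask to
    float gives 1.0 there and 0.0 elsewhere, so only exact zeros change.
    """
    return x + np.equal(x, 0.0).astype(x.dtype)


m_dist_1 = np.array([[[0.0, 2.5, 0.3]]], dtype=np.float32)  # [B=1, 1, K=3]
safe = replace_zeros_with_one(m_dist_1)

# Assumed use as a divisor: 1 / m_dist_1 would be inf for the zero entry,
# while 1 / safe stays finite.
inv = np.float32(1.0) / safe
print(safe)
print(inv)
```

The test is exact equality, so a tiny but nonzero mean distance passes through unchanged; the guard only catches entries that are exactly 0.0.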