waveglow-vqvae
Some confusion about soft-EM
Hi, since your WaveGlow-VQVAE proposes to use a soft-EM version of VQ-VAE, the core implementation is:

```python
def _square_distance(x, code_book):
    x = tf.cast(x, tf.float32)
    code_book = tf.cast(code_book, tf.float32)
    x_sg = tf.stop_gradient(x)
    x_norm_sq = tf.reduce_sum(tf.square(x_sg), axis=-1, keepdims=True)  # [b, 1]
    code_book_norm_sq = tf.reduce_sum(tf.square(code_book), axis=-1, keepdims=True)  # [V, 1]
    scalar_prod = tf.matmul(x_sg, code_book, transpose_b=True)  # [b, V]
    dist_sq = x_norm_sq + tf.transpose(code_book_norm_sq) - 2 * scalar_prod  # [b, V]
    return tf.cast(dist_sq, x.dtype.base_dtype)

dist_sq = _square_distance(x, code_book)
# E-step: soft assignment over the codebook.
q = tf.stop_gradient(tf.nn.softmax(-.5 * dist_sq))
# Hard (nearest-code) assignment used for the forward pass.
discrete = tf.one_hot(tf.argmax(-dist_sq, axis=-1),
                      depth=bottleneck_size, dtype=code_book.dtype.base_dtype)
dense = tf.matmul(discrete, code_book)
dense = dense + x - tf.stop_gradient(x)  # straight-through gradient

def _get_losses(x, x_mask, dense, dist_sq, q):
    x = tf.cast(x, tf.float32)
    x_mask = tf.cast(x_mask, tf.float32)
    dense = tf.cast(dense, tf.float32)
    dist_sq = tf.cast(dist_sq, tf.float32)
    q = tf.cast(q, tf.float32)
    disc_loss = tf.reduce_sum(
        tf.reduce_sum(tf.square(x - tf.stop_gradient(dense)), -1) * x_mask
    ) / (1e-10 + tf.reduce_sum(x_mask))
    # M-step
    em_loss = -tf.reduce_sum(
        tf.reduce_sum(-.5 * dist_sq * q, -1) * x_mask
    ) / (1e-10 + tf.reduce_sum(x_mask))
    return disc_loss, em_loss

disc_loss, em_loss = _get_losses(x, x_mask, dense, dist_sq, q)
```

However, tensor2tensor has a different implementation: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/layers/vq_discrete.py
- It multisamples to get the mean of the soft alignment.
- When calculating the EM loss, it uses a different type of loss function compared to your "M-step" (see my sketch after this list). Could you help me with it?
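My current understanding of the tensor2tensor soft-EM path, written against the shapes of your snippet above, is roughly the sketch below. This is my own reading, not the actual tensor2tensor code; `num_samples` and the temperature on the logits are assumptions on my part.

```python
import tensorflow as tf

def soft_em_bottleneck(x, code_book, num_samples=10):
    """Sketch of my reading of t2t-style soft EM (not verbatim t2t code).

    x:         [b, d] encoder outputs
    code_book: [V, d] codebook embeddings
    """
    bottleneck_size = tf.shape(code_book)[0]
    x_norm_sq = tf.reduce_sum(tf.square(x), axis=-1, keepdims=True)           # [b, 1]
    cb_norm_sq = tf.reduce_sum(tf.square(code_book), axis=-1, keepdims=True)  # [V, 1]
    dist = (x_norm_sq + tf.transpose(cb_norm_sq)
            - 2 * tf.matmul(x, code_book, transpose_b=True))                  # [b, V]

    # Multisample: draw code indices from the categorical posterior
    # q(z|x) proportional to exp(-dist), then average the one-hot samples.
    idx = tf.random.categorical(-dist, num_samples=num_samples)               # [b, S]
    hot = tf.one_hot(idx, depth=bottleneck_size)                              # [b, S, V]
    q_bar = tf.reduce_mean(hot, axis=1)                                       # [b, V]

    # EM loss as a cross-entropy between the averaged sampled assignment
    # and the logits -dist, instead of the expected distance in your M-step.
    em_loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(
            labels=tf.stop_gradient(q_bar), logits=-dist))

    dense = tf.matmul(q_bar, code_book)  # soft codebook lookup
    return dense, em_loss
```

If I expand the cross-entropy, `-sum(q_bar * log_softmax(-dist)) = sum(q_bar * dist) + logsumexp(-dist)`, so it keeps the log-partition term that your `em_loss` (an expected distance under a stop-gradiented `q`) drops. Is that the intended difference, or is there something else I am missing?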