MojiTalk

Is there a PyTorch version of this code?

KingS770234358 opened this issue on Feb 24, 2020 · 0 comments

I'm trying to implement this paper in PyTorch, but I don't know how to convert this part. The training decoder is built like this:

```python
    with tf.variable_scope("decoder_train") as decoder_scope:
        if decoder_layer == 2:
            train_decoder_init_state = (
                tf.concat([self.z_sample, ori_encoder_state[0], emoji_vec], axis=1),
                tf.concat([self.z_sample, ori_encoder_state[1], emoji_vec], axis=1)
            )
            dim = latent_dim + num_unit + emoji_dim
            cell = tf.nn.rnn_cell.MultiRNNCell(
                [create_rnn_cell(dim, 2, cell_type, num_gpu, self.dropout),
                 create_rnn_cell(dim, 3, cell_type, num_gpu, self.dropout)])
        else:
            train_decoder_init_state = tf.concat([self.z_sample, ori_encoder_state_flat, emoji_vec], axis=1)
            dim = latent_dim + 2 * num_unit + emoji_dim
            cell = create_rnn_cell(dim, 2, cell_type, num_gpu, self.dropout)

        with tf.variable_scope("attention"):
            memory = tf.concat([ori_encoder_output[0], ori_encoder_output[1]], axis=2)
            memory = tf.transpose(memory, [1, 0, 2])

            attention_mechanism = seq2seq.LuongAttention(
                dim, memory, memory_sequence_length=self.ori_len, scale=True)
            # attention_mechanism = seq2seq.BahdanauAttention(
            #     num_unit, memory, memory_sequence_length=self.ori_len)

        decoder_cell = seq2seq.AttentionWrapper(
            cell,
            attention_mechanism,
            attention_layer_size=dim)  # TODO: add a name; check what attention_layer_size means
        # decoder_cell = cell

        helper = seq2seq.TrainingHelper(
            rep_input_emb, self.rep_len + 1, time_major=True)

        projection_layer = layers_core.Dense(
            vocab_size, use_bias=False, name="output_projection")
        decoder = seq2seq.BasicDecoder(
            decoder_cell, helper,
            decoder_cell.zero_state(batch_size, tf.float32).clone(cell_state=train_decoder_init_state),
            output_layer=projection_layer)
        train_outputs, _, _ = seq2seq.dynamic_decode(
            decoder,
            output_time_major=True,
            swap_memory=True,
            scope=decoder_scope
        )
        self.logits = train_outputs.rnn_output
```
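
Here is my rough attempt at translating the single-layer branch (`decoder_layer != 2`) to PyTorch. PyTorch has no direct equivalent of `seq2seq.AttentionWrapper`, so I hand-rolled the scaled Luong attention and the input-feeding loop; the class names, the GRU cell, and dropping dropout are my own assumptions, and the `decoder_layer == 2` two-cell stack is skipped. Does this look like a faithful translation?

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LuongAttention(nn.Module):
    """My stand-in for seq2seq.LuongAttention(..., scale=True)."""

    def __init__(self, query_dim, memory_dim):
        super().__init__()
        self.w = nn.Linear(memory_dim, query_dim, bias=False)
        self.g = nn.Parameter(torch.ones(1))  # the scale=True scalar

    def forward(self, query, memory, memory_mask):
        # query: [batch, query_dim]; memory: [batch, src_len, memory_dim]
        scores = self.g * torch.bmm(self.w(memory), query.unsqueeze(2)).squeeze(2)
        scores = scores.masked_fill(~memory_mask, float("-inf"))  # mask padding
        weights = F.softmax(scores, dim=1)  # [batch, src_len]
        context = torch.bmm(weights.unsqueeze(1), memory).squeeze(1)
        return context, weights


class TrainDecoder(nn.Module):
    """dim = latent_dim + 2 * num_unit + emoji_dim, as in the else branch above."""

    def __init__(self, embed_dim, latent_dim, num_unit, emoji_dim, vocab_size):
        super().__init__()
        dim = latent_dim + 2 * num_unit + emoji_dim
        # AttentionWrapper feeds the previous attention into the next cell input
        self.cell = nn.GRUCell(embed_dim + dim, dim)
        self.attention = LuongAttention(dim, 2 * num_unit)
        # attention_layer_size=dim -> dense layer from [state; context] back to dim
        self.attn_layer = nn.Linear(dim + 2 * num_unit, dim, bias=False)
        self.projection = nn.Linear(dim, vocab_size, bias=False)  # output_projection

    def forward(self, rep_input_emb, z_sample, ori_encoder_state_flat, emoji_vec,
                memory, memory_mask):
        # rep_input_emb: [tgt_len, batch, embed_dim], time-major like the TF code
        state = torch.cat([z_sample, ori_encoder_state_flat, emoji_vec], dim=1)
        attn = state.new_zeros(state.shape)
        logits = []
        for emb_t in rep_input_emb:  # TrainingHelper = teacher forcing
            state = self.cell(torch.cat([emb_t, attn], dim=1), state)
            context, _ = self.attention(state, memory, memory_mask)
            attn = torch.tanh(self.attn_layer(torch.cat([state, context], dim=1)))
            logits.append(self.projection(attn))
        return torch.stack(logits)  # [tgt_len, batch, vocab_size] = self.logits
```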
    with tf.variable_scope("decoder_infer") as decoder_scope:
        # normal_sample = tf.random_normal(shape=(batch_size, latent_dim))

        if decoder_layer == 2:
            infer_decoder_init_state = (
                tf.concat([self.q_z_sample, ori_encoder_state[0], emoji_vec], axis=1),
                tf.concat([self.q_z_sample, ori_encoder_state[1], emoji_vec], axis=1)
            )
        else:
            infer_decoder_init_state = tf.concat([self.q_z_sample, ori_encoder_state_flat, emoji_vec], axis=1)

        start_tokens = tf.fill([batch_size], start_i)
        end_token = end_i

        if beam_width > 0:
            infer_decoder_init_state = seq2seq.tile_batch(
                infer_decoder_init_state, multiplier=beam_width)
            decoder = seq2seq.BeamSearchDecoder(
                cell=decoder_cell,
                embedding=embedding.coder,
                start_tokens=start_tokens,
                end_token=end_token,
                initial_state=decoder_cell.zero_state(
                    batch_size * beam_width, tf.float32).clone(cell_state=infer_decoder_init_state),
                beam_width=beam_width,
                output_layer=projection_layer,
                length_penalty_weight=0.0)
        else:
            helper = seq2seq.GreedyEmbeddingHelper(
                embedding.coder, start_tokens, end_token)
            decoder = seq2seq.BasicDecoder(
                decoder_cell,
                helper,
                decoder_cell.zero_state(batch_size, tf.float32).clone(cell_state=infer_decoder_init_state),
                output_layer=projection_layer  # applied per timestep
            )

        # Dynamic decoding
        infer_outputs, _, infer_lengths = seq2seq.dynamic_decode(
            decoder,
            maximum_iterations=maximum_iterations,
            output_time_major=True,
            swap_memory=True,
            scope=decoder_scope
        )
        if beam_width > 0:
            self.result = infer_outputs.predicted_ids
        else:
            self.result = infer_outputs.sample_id
            self.result_lengths = infer_lengths
```
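
For inference I have only attempted the greedy branch (`beam_width == 0`); this is my sketch of what `GreedyEmbeddingHelper` plus `dynamic_decode` do, reusing the modules above (`embedding` here is an `nn.Embedding` standing in for `embedding.coder`). For the beam branch I assume the memory, mask, and initial state have to be tiled `beam_width` times, like `seq2seq.tile_batch`, but I haven't gotten that working yet.

```python
import torch


@torch.no_grad()
def greedy_decode(decoder, embedding, q_z_sample, ori_encoder_state_flat, emoji_vec,
                  memory, memory_mask, start_i, end_i, maximum_iterations):
    """Rough equivalent of the beam_width == 0 branch above."""
    batch_size = memory.size(0)
    state = torch.cat([q_z_sample, ori_encoder_state_flat, emoji_vec], dim=1)
    attn = state.new_zeros(state.shape)
    tokens = torch.full((batch_size,), start_i, dtype=torch.long, device=memory.device)
    finished = torch.zeros(batch_size, dtype=torch.bool, device=memory.device)
    lengths = torch.zeros(batch_size, dtype=torch.long, device=memory.device)
    result = []
    for _ in range(maximum_iterations):
        emb_t = embedding(tokens)  # GreedyEmbeddingHelper: embed previous output
        state = decoder.cell(torch.cat([emb_t, attn], dim=1), state)
        context, _ = decoder.attention(state, memory, memory_mask)
        attn = torch.tanh(decoder.attn_layer(torch.cat([state, context], dim=1)))
        tokens = decoder.projection(attn).argmax(dim=1)  # greedy pick per step
        result.append(tokens)
        lengths += (~finished).long()  # count steps up to and including end_token
        finished |= tokens.eq(end_i)
        if finished.all():  # dynamic_decode stops once every sequence is finished
            break
    return torch.stack(result), lengths  # time-major ids, like sample_id / result_lengths
```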
    with tf.variable_scope("loss"):
        max_time = tf.shape(self.rep_output)[0]
        with tf.variable_scope("reconstruction"):
            # TODO: use inference decoder's logits to compute recon_loss
            cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(  # ce = [len, batch_size]
                labels=self.rep_output, logits=self.logits)
            # rep: [len, batch_size]; logits: [len, batch_size, vocab_size]
            target_mask = tf.sequence_mask(
                self.rep_len + 1, max_time, dtype=self.logits.dtype)
            # transpose to time-major
            target_mask_t = tf.transpose(target_mask)  # [max_time, batch_size]
            self.recon_losses = tf.reduce_sum(cross_entropy * target_mask_t, axis=0)
            self.recon_loss = tf.reduce_sum(cross_entropy * target_mask_t) / batch_size

        with tf.variable_scope("latent"):
            # without prior network
            # self.kl_loss = 0.5 * tf.reduce_sum(tf.exp(self.log_var) + self.mu ** 2 - 1. - self.log_var, 0)
            self.kl_losses = 0.5 * tf.reduce_sum(
                tf.exp(self.log_var - self.p_log_var) +
                (self.mu - self.p_mu) ** 2 / tf.exp(self.p_log_var) - 1. - self.log_var + self.p_log_var,
                axis=1)
            self.kl_loss = tf.reduce_mean(self.kl_losses)

        with tf.variable_scope("bow"):
            # self.bow_loss = self.kl_weight * 0
            mlp_b = layers_core.Dense(
                vocab_size, use_bias=False, name="MLP_b")
            # is it a mistake that we only model on the latent variable?
            latent_logits = mlp_b(tf.concat(
                [self.z_sample, ori_encoder_state_flat, emoji_vec], axis=1))  # [batch_size, vocab_size]
            latent_logits = tf.expand_dims(latent_logits, 0)  # [1, batch_size, vocab_size]
            latent_logits = tf.tile(latent_logits, [max_time, 1, 1])  # [max_time, batch_size, vocab_size]

            cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(  # ce = [len, batch_size]
                labels=self.rep_output, logits=latent_logits)
            self.bow_losses = tf.reduce_sum(cross_entropy * target_mask_t, axis=0)
            self.bow_loss = tf.reduce_sum(cross_entropy * target_mask_t) / batch_size

```
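
And this is my attempt at the three loss terms. I believe `F.cross_entropy(..., reduction="none")` matches `tf.nn.sparse_softmax_cross_entropy_with_logits`, and the `kl_losses` line is the same closed-form KL between the posterior N(mu, exp(log_var)) and the prior network's N(p_mu, exp(p_log_var)). `latent_logits` would come from an `nn.Linear(..., vocab_size, bias=False)` playing the role of `MLP_b`. Am I masking and reducing these the same way the TF code does?

```python
import torch
import torch.nn.functional as F


def sequence_mask_t(lengths, max_time):
    """tf.sequence_mask, transposed to time-major: [max_time, batch]."""
    steps = torch.arange(max_time, device=lengths.device).unsqueeze(1)
    return (steps < lengths.unsqueeze(0)).float()


def mojitalk_losses(logits, latent_logits, rep_output, rep_len,
                    mu, log_var, p_mu, p_log_var):
    # logits: [max_time, batch, vocab]; latent_logits: [batch, vocab]
    # rep_output: [max_time, batch] token ids; rep_len: [batch]
    max_time, batch_size, vocab_size = logits.shape
    mask_t = sequence_mask_t(rep_len + 1, max_time)

    # reconstruction: per-token cross entropy, masked, summed over time
    ce = F.cross_entropy(logits.reshape(-1, vocab_size), rep_output.reshape(-1),
                         reduction="none").view(max_time, batch_size)
    recon_losses = (ce * mask_t).sum(dim=0)  # per example
    recon_loss = recon_losses.sum() / batch_size

    # KL( N(mu, var) || N(p_mu, p_var) ), summed over latent dims, mean over batch
    kl_losses = 0.5 * torch.sum(
        torch.exp(log_var - p_log_var)
        + (mu - p_mu) ** 2 / torch.exp(p_log_var)
        - 1.0 - log_var + p_log_var, dim=1)
    kl_loss = kl_losses.mean()

    # bag-of-words: tile the latent logits over time, like tf.tile above
    bow_logits = latent_logits.unsqueeze(0).expand(max_time, -1, -1)
    bow_ce = F.cross_entropy(bow_logits.reshape(-1, vocab_size), rep_output.reshape(-1),
                             reduction="none").view(max_time, batch_size)
    bow_losses = (bow_ce * mask_t).sum(dim=0)
    bow_loss = bow_losses.sum() / batch_size
    return recon_loss, kl_loss, bow_loss
```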
