Char-RNN-TensorFlow icon indicating copy to clipboard operation
Char-RNN-TensorFlow copied to clipboard

请问seq_output = tf.concat(self.lstm_outputs, 1)的用意是什么?

Open Eddiechiu opened this issue 6 years ago • 7 comments

你好,想请教个问题。 我的运行下来,报错在 y_reshaped = tf.reshape(y_one_hot, self.logits.get_shape())

因为y_one_hot和self.logits的总元素数量不同,所以不能reshape。 我推算了一下:

  1. inputs的shape是(num_seqs, num_steps),经过tf.one_hot以后,lstm_inputs的shape变成(num_seqs, num_steps, num_classes)

  2. 我用的是cell是一层的lstm,lstm_inputs经过tf.nn.dynamic(cell, lstm_inputs, initial_state=self.initial_state)后,lstm_outputs的shape是(num_seqs, num_steps, lstm_size)

  3. lstm_outputs经过tf.concat(lstm_outputs, 1)以后,shape没有任何变化,再经过一些列运算后,shape就会有问题。

所以想问一下tf.concat(lstm_outputs, 1)这一步是做什么的? 感谢~

Eddiechiu avatar Jul 06 '18 02:07 Eddiechiu

我认为,self.lstm_inputs的shape在经过embedding_lookup后,应该是(num_seqs, num_steps, embedding_size)。也就是一个input由embedding_size大小的向量表示。

charmpeng avatar Jul 12 '18 02:07 charmpeng

嗯,你的embedding_size就是我的num_classes, 但是tf.concat(lstm_outputs, 1)这一步我没懂,而且跳过这一步程序可以正常运行。

Eddiechiu avatar Jul 12 '18 02:07 Eddiechiu

'with tf.name_scope('lstm'):
        cell = tf.nn.rnn_cell.MultiRNNCell(
            [get_a_cell(self.lstm_size, self.keep_prob) for _ in range(self.num_layers)]
        )
        self.initial_state = cell.zero_state(self.num_seqs, tf.float32)

        # 通过dynamic_rnn对cell展开时间维度
        self.lstm_outputs, self.final_state = tf.nn.dynamic_rnn(cell, self.lstm_inputs, initial_state=self.initial_state)
        print("self.lstm_outputs.get_shape",self.lstm_outputs.get_shape()) # (32,50,128)
        seq_output = tf.concat(self.lstm_outputs, 1)
        print("seq_output.get_shape()",seq_output.get_shape())  # (32,50,128)
        x = tf.reshape(seq_output, [-1, self.lstm_size])
        print("x.get_shape()",x.get_shape())  # (1600,128)`

我打印出了shape,看来seq_output = tf.concat(self.lstm_outputs, 1) 这一句并没什么用处,因为concat前后,lstm_outputs与seq_output的shape都是一样的。

charmpeng avatar Jul 12 '18 05:07 charmpeng

嗯嗯,是的,所以我就直接reshape了 =D

Eddiechiu avatar Jul 12 '18 12:07 Eddiechiu

这一步确实没用。

sunnima avatar Mar 19 '19 10:03 sunnima

确实没用。。。

Natumsol avatar Oct 23 '19 07:10 Natumsol

tf.concat(values, 1)#values必须是序列,在这里不起作用

cherish6092 avatar Sep 22 '21 11:09 cherish6092