anna_lstm: both versions hit the same problem

Open 1mrliu opened this issue 6 years ago • 4 comments

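Both versions of the LSTM code fail with the same error when I run them. I have not been able to track down the bug, so I would like to ask for help. The error is raised at this point in the code: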

# Run the RNN
outputs, state = tf.nn.dynamic_rnn(cell, x_one_hot, initial_state=self.initial_state)

The error message (it looks like a tensor shape problem, but I checked the code and could not find where it comes from): Dimensions must be equal, but are 1024 and 595 for 'rnn/while/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/MatMul_1' (op: 'MatMul') with input shapes: [1,1024], [595,2048].

1mrliu avatar Jul 17 '18 02:07 1mrliu

Same here, and I cannot see why: ValueError: Dimensions must be equal, but are 1024 and 595 for 'rnn/while/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/MatMul_1' (op: 'MatMul') with input shapes: [100,1024], [595,2048].
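The shapes in the two messages can be decoded with a little arithmetic. The sketch below assumes the kernel layout of `BasicLSTMCell` in TF 1.x, `[input_depth + num_units, 4 * num_units]`. The values `num_units = 512` and a one-hot input width of 83 are inferred from the numbers in the error, not stated in the thread. The leading 1 or 100 in the first operand appears to be the batch size.

```python
def lstm_kernel_shape(input_depth, num_units):
    """Kernel shape of one LSTM layer: rows cover [inputs, h], columns cover the 4 gates."""
    return (input_depth + num_units, 4 * num_units)


# Inferred from the error message (assumption, not stated in the thread).
num_units = 2048 // 4          # the 2048 columns are split across 4 gates
vocab_size = 595 - num_units   # width of the one-hot input to the first layer

layer0_kernel = lstm_kernel_shape(vocab_size, num_units)
layer1_kernel = lstm_kernel_shape(num_units, num_units)

# The second layer multiplies [layer-0 output, its own h] by its kernel.
layer1_input_width = num_units + num_units

print(layer0_kernel)       # (595, 2048)
print(layer1_kernel)       # (1024, 2048)
print(layer1_input_width)  # 1024
```

Read this way, `[595,2048]` has the shape of a first-layer kernel, while the 1024-wide operand has the width of a second-layer input, so the failing `MatMul` is pairing a second-layer input with a first-layer kernel.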

PengInGitHub avatar Feb 02 '19 09:02 PengInGitHub

Hi, has this been solved? I ran into the same problem and do not know what causes it.

snow123321 avatar Feb 14 '19 08:02 snow123321

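Rewrite the build_lstm function as follows; tested on TF 1.11.

def lstm_cell():
    # Build a new LSTM cell, with dropout on its output, on every call.
    # Note: lstm_size and keep_prob are read from the enclosing (global) scope.
    lstm = tf.nn.rnn_cell.BasicLSTMCell(lstm_size)
    return tf.nn.rnn_cell.DropoutWrapper(lstm, output_keep_prob=keep_prob)

def build_lstm(lstm_size, num_layers, batch_size, keep_prob):
    '''
    Build the LSTM layers.

    keep_prob: dropout keep probability
    lstm_size: number of units in each LSTM hidden layer
    num_layers: number of LSTM hidden layers
    batch_size: batch_size
    '''
    # Call lstm_cell() once per layer so that every layer gets its own cell.
    cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell() for _ in range(num_layers)])
    initial_state = cell.zero_state(batch_size, tf.float32)
    return cell, initial_state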

stud2008 avatar Mar 08 '19 07:03 stud2008

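That's right. The original code effectively builds only one LSTM cell, lstm = tf.contrib.rnn.BasicLSTMCell(lstm_size), and then cell = tf.contrib.rnn.MultiRNNCell([drop for _ in range(num_layers)]) adds that same cell to the stack repeatedly. That is why you need a function: it guarantees that a new LSTM cell is created each time one is added.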
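The difference between repeating one object and calling a factory can be seen without TensorFlow. The following is only a toy sketch: `ToyCell` and `run_stack` are hypothetical stand-ins that mimic a cell whose kernel is created on first use, and the sizes 512 and 83 are inferred from the error message rather than given in the thread.

```python
class ToyCell:
    """Hypothetical stand-in for an LSTM cell whose kernel is created on the first call."""

    def __init__(self, num_units):
        self.num_units = num_units
        self.kernel_shape = None  # built lazily on first use

    def __call__(self, input_width):
        rows = input_width + self.num_units
        if self.kernel_shape is None:
            self.kernel_shape = (rows, 4 * self.num_units)
        elif self.kernel_shape[0] != rows:
            raise ValueError(
                "Dimensions must be equal, but are %d and %d" % (rows, self.kernel_shape[0])
            )
        return self.num_units  # output width passed to the next layer


def run_stack(cells, input_width):
    """Pass the input width through each layer in turn, as a stacked RNN would."""
    for cell in cells:
        input_width = cell(input_width)
    return input_width


num_layers, lstm_size, vocab_size = 2, 512, 83  # assumed sizes

# Original pattern: the same cell object appears at every position in the list.
shared = ToyCell(lstm_size)
shared_stack = [shared for _ in range(num_layers)]
try:
    run_stack(shared_stack, vocab_size)
    shared_error = None
except ValueError as exc:
    shared_error = str(exc)

# Fixed pattern: one constructor call per layer gives independent cells.
fresh_stack = [ToyCell(lstm_size) for _ in range(num_layers)]
run_stack(fresh_stack, vocab_size)

print(shared_stack[0] is shared_stack[1])  # True
print(fresh_stack[0] is fresh_stack[1])    # False
print(shared_error)  # Dimensions must be equal, but are 1024 and 595
```

In the shared stack, the second layer reuses the kernel that was sized for the first layer's input, which reproduces the 1024 versus 595 mismatch. In the fresh stack each layer sizes its own kernel.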

li-aolong avatar Aug 15 '19 09:08 li-aolong