zhihu
zhihu copied to clipboard
training_logits, targets维度不匹配
cost = tf.contrib.seq2seq.sequence_loss(training_logits, targets, masks)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [5632] vs. [6400]
就是我的target是(256,25) 可是输出得到的training_logits却是(256, 22, 358)358:词表数
我改了一下,这样就对了
def pad_batch_sentence(batch, max_length, pad_id):
# max_length = max([len(sentence) for sentence in batch])
return [sentence + [pad_id] * (max_length - len(sentence)) for sentence in batch]
def get_batches(sources, targets, batch_size):
for batch_i in range(0, len(sources) // batch_size):
start_i = batch_i * batch_size
# Slice the right amount for the batch
sources_batch = sources[start_i:start_i + batch_size]
targets_batch = targets[start_i:start_i + batch_size]
pad_idx = source_vocab_to_int.get("<PAD>")
sources_batch_pad = np.array(pad_batch_sentence(sources_batch, max_source_sentence_length, pad_idx))
targets_batch_pad = np.array(pad_batch_sentence(targets_batch, max_target_sentence_length, pad_idx))
# Need the lengths for the _lengths parameters
# 不应该是对pad过的batch做长度的计算,因为都是25
targets_lengths = []
for target in targets_batch_pad:
targets_lengths.append(len(target))
source_lengths = []
for source in sources_batch_pad:
source_lengths.append(len(source))
yield sources_batch_pad, targets_batch_pad, source_lengths, targets_lengths
可是这样传入的source_lengths都是(20,20,20...)
targets_lengths都是(25, 25, 25...)
我也觉得这块有点问题,这样source长度全是padding以后的最大长度。。
改了以后会报错。。。