Adversarial-Learning-for-Neural-Dialogue-Generation-in-Tensorflow icon indicating copy to clipboard operation
Adversarial-Learning-for-Neural-Dialogue-Generation-in-Tensorflow copied to clipboard

gen.decode生成的结果有问题…

Open JiaheChang opened this issue 6 years ago • 4 comments

您好,我试图运行这个程序,因为我想要跑不同的数据,所以我改动了pre_data() 这部分的代码,但是在运行过程中,出现了一些问题。 当我运行gen.train()的时候,结果看上去是正常的,perplexity也在减小。但是我想用gen.decode()来生成discriminator的训练数据时,我发现生成的train.gen文件是空的。 然后我试图把out_logits打出来,结果发现每一行的预测数据相差都不大,然后在这部分代码中: for seq in out_logits: token = [] for t in seq: token.append(int(np.argmax(t, axis=0))) tokens.append(token) np.argmax(t,axis = 0),结果都是一个数,在我的vocal里正好对应了" ",所以导致train.gen里每一个预测出来的句子都是" "组成的.... 我现在完全不明白自己哪里错了……希望可以得到您的帮助……非常感谢……

JiaheChang avatar Aug 01 '17 00:08 JiaheChang

您好,而且我现在非常不明白的一点是,为什么在train.gen里每一行的结果都一样呢……

JiaheChang avatar Aug 02 '17 00:08 JiaheChang

在我实验中没有遇到这种情况,你可以先熟悉下代码。相信你自己可以解决这个问题的。我现在一直忙于其他项目,目前脱不开身。抱歉……

liuyuemaicha avatar Aug 03 '17 11:08 liuyuemaicha

是否是因为生成器预训练不够?train.gen是生成器生成的负例数据。如果生成器预训练不够的话,生成的效果自然好不到哪儿去。

imageslr avatar Nov 15 '18 08:11 imageslr

你好 @liuyuemaicha ,我对生成器的get_batch一直有个疑问,到现在还没解决:

如下所示,在type==2的时候,encoder_input应该是只有一个单词id的吧?是一个整数。但是下面又求了它的长度len(encoder_input),这是为什么?难道type=2的时候读入的数据和其他时候不一样吗?

train_data的shape是否是[bucket_num, batch_size, 2, encoder_size或decoder_size]

    for batch_i in xrange(batch_size):
        if type == 1:  # 返回桶内所有数据
            encoder_input, decoder_input = train_data[bucket_id]
        elif type == 2:  # 取桶内第一组,encoder_input是第batch_i个单词,encoder只有一个单词 # TODO 但下面是把它当数组用的,这里就报错了
            # print("disc_data[bucket_id]: ", disc_data[bucket_id][0])
            encoder_input_a, decoder_input = train_data[bucket_id][0]
            encoder_input = encoder_input_a[batch_i]
        elif type == 0:  # 桶内随机挑一组
            encoder_input, decoder_input = random.choice(train_data[bucket_id])
            # print("train en: %s, de: %s" %(encoder_input, decoder_input))

        # ...
        # 这里求了长度
        encoder_pad = [data_utils.PAD_ID] * (encoder_size - len(encoder_input))

这个问题困扰我好久了,代码也仔细读了好几遍,还是看不懂这个。如果你能指点一下的话,感激不尽!

imageslr avatar Nov 15 '18 08:11 imageslr