MultiTurnDialogZoo

When can models stop training?


I observed that even after the model's Test PPL stops falling, the other test metrics keep improving. For example, Valid Loss and Test PPL may be lowest at epoch 22, but beyond epoch 22 the BLEU and Distinct scores continue to rise. According to the experimental procedure, the model parameters should in theory be saved at the lowest point of the validation loss, and the model's evaluation results on the test set at that checkpoint should then be reported as the final results. Is that right?

ZihaoW123 · Sep 23, 2020

Typically, the checkpoint should be saved when the lowest loss (PPL) is achieved. But in my experiments I found that the model's performance can be further improved by training for more epochs. In my opinion, the generative task is a little different from the classification task, so you can try training longer even after the lowest loss has been reached.
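In code, that checkpoint policy might look roughly like the sketch below: keep the checkpoint with the lowest validation loss, but keep training for a while longer and snapshot every epoch so BLEU/Distinct of later epochs can still be compared afterwards. Here train_one_epoch, validate, and the patience value are stand-in placeholders, not code from this repo.

    import math
    import torch

    def train_with_patience(model, train_one_epoch, validate,
                            max_epochs=50, patience=5, ckpt_dir='ckpt'):
        # Keep the best-valid-loss checkpoint, but also snapshot every epoch
        # so generation metrics of later epochs can be evaluated afterwards.
        best_loss, best_epoch = float('inf'), -1
        for epoch in range(max_epochs):
            train_one_epoch(model)           # stand-in: one pass over the training set
            valid_loss = validate(model)     # stand-in: mean NLL on the validation set
            print(f'epoch {epoch}: valid loss {valid_loss:.4f}, ppl {math.exp(valid_loss):.2f}')
            torch.save(model.state_dict(), f'{ckpt_dir}/epoch_{epoch}.pt')
            if valid_loss < best_loss:
                best_loss, best_epoch = valid_loss, epoch
                torch.save(model.state_dict(), f'{ckpt_dir}/best.pt')
            elif epoch - best_epoch > patience:
                # the loss has been flat for a while; BLEU/Distinct of the
                # extra epochs can still be checked from the per-epoch snapshots
                break
        return best_epoch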

gmftbyGMFTBY · Sep 23, 2020

Thank you very much for your project. The code is beautiful, the logic is clear, and I've learned a lot from it. Will you continue to open-source GAN-based dialog models, RL-based conversation models, or transformer-based dialog models? I'm looking forward to it.

Besides, I see that in data_loader.py the data order is fixed first (the dialogs are sorted by turn length). Although shuffling is applied within each batch afterwards, global randomness is never achieved.

turns = [len(dialog) for dialog in src_dataset]
turnidx = np.argsort(turns)
# sort by the length of the turns
src_dataset = [src_dataset[idx] for idx in turnidx]
tgt_dataset = [tgt_dataset[idx] for idx in turnidx]

...

shuffleidx = np.arange(0, len(sbatch))
np.random.shuffle(shuffleidx)
sbatch = [sbatch[idx] for idx in shuffleidx]
tbatch = [tbatch[idx] for idx in shuffleidx]

Does that make a difference? Does the model memorize the order of the data, leading to overfitting?

After I added global shuffling, the MReCoSa model's Valid Loss curve looks better, and the Test PPL is also lower. Some of the code changes are as follows:

    turns = [len(dialog) for dialog in src_dataset]
    fidx, bidx = 0, 0
    fidx_bidx_list = []
    while fidx < len(src_dataset):
        bidx = fidx + batch_size
        head = turns[fidx]
        cidx = 10000    # sentinel: no turn-count change found inside this window
        # if the number of turns changes inside the window, cut the batch there
        for p, i in enumerate(turns[fidx:bidx]):
            if i != head:
                cidx = p
                break
        cidx = fidx + cidx
        bidx = min(bidx, cidx)
        # print(fidx, bidx)

        # batch, [batch, turns, lengths], [batch, lengths]
        # shuffle
        # sbatch= src_dataset[fidx:bidx]
        if bidx - fidx <= plus:    # skip batches of size <= plus (plus comes from the enclosing function)
            fidx = bidx
            continue
        fidx_bidx_list.append([fidx, bidx])
        fidx = bidx
    shuffleidx = np.arange(0, len(fidx_bidx_list))
    np.random.shuffle(shuffleidx)
    fidx_bidx_list_ = [fidx_bidx_list[i] for i in shuffleidx]

    for fidx, bidx in fidx_bidx_list_:
        sbatch, tbatch = src_dataset[fidx:bidx], tgt_dataset[fidx:bidx]
        shuffleidx = np.arange(0, len(sbatch))
        np.random.shuffle(shuffleidx)
        sbatch = [sbatch[idx] for idx in shuffleidx]
        tbatch = [tbatch[idx] for idx in shuffleidx]
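
For reference, a self-contained sketch of the same batching idea (same-turn-count buckets, batch order shuffled globally, then shuffling inside each batch) might look like the following; it ignores the padding and the plus threshold handled elsewhere in data_loader.py:

    import numpy as np

    def bucketed_batches(src_dataset, tgt_dataset, batch_size):
        # sort dialogs by number of turns so each batch holds a single turn count
        turns = np.array([len(dialog) for dialog in src_dataset])
        order = np.argsort(turns)
        src = [src_dataset[i] for i in order]
        tgt = [tgt_dataset[i] for i in order]
        sorted_turns = turns[order]

        # cut [fidx, bidx) batches that never mix different turn counts
        bounds, fidx = [], 0
        while fidx < len(src):
            window = sorted_turns[fidx:fidx + batch_size]
            bidx = fidx + int((window == sorted_turns[fidx]).sum())
            bounds.append((fidx, bidx))
            fidx = bidx

        np.random.shuffle(bounds)    # global shuffle over batches
        for fidx, bidx in bounds:
            idx = fidx + np.random.permutation(bidx - fidx)    # shuffle inside the batch
            yield [src[i] for i in idx], [tgt[i] for i in idx]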

ZihaoW123 · Sep 23, 2020

Thank you so much for your attention to this repo.

  1. As for the GAN-based model, it may take me about a month to implement.
  2. As for the transformer-based models, you can check my other repo OpenDialog, which contains some transformer-based retrieval and generative dialog models.
  3. Thank you for your improvements to this repo. I think you are doing a great job, and I will consider your suggestions.

gmftbyGMFTBY · Sep 23, 2020

I'm a little confused about some code in DSHRED.py:

def forward(self, inpt, hidden=None):
    # inpt: [turn_len, batch, input_size]
    # hidden
    # ALSO RETURN THE STATIC ATTENTION
    if not hidden:
        hidden = torch.randn(2, inpt.shape[1], self.hidden_size)
        if torch.cuda.is_available():
            hidden = hidden.cuda()

    # inpt = self.drop(inpt)
    # output: [seq, batch, 2 * hidden_size]
    output, hidden = self.gru(inpt, hidden)
    # output: [seq, batch, hidden_size]
    output = output[:, :, :self.hidden_size] + output[:, :, self.hidden_size:]

    # static attention
    static_attn = self.attn(output[0].unsqueeze(0), output)    # <- the line I am asking about
    static_attn = static_attn.bmm(output.transpose(0, 1))
    static_attn = static_attn.transpose(0, 1)    # [1, batch, hidden]

    # hidden: [1, batch, hidden_size]
    # hidden = hidden.squeeze(0)    # [batch, hidden_size]
    hidden = torch.tanh(hidden)
    return static_attn, output, hidden

Why does the marked line use self.attn(output[0].unsqueeze(0), output) and not self.attn(output[-1].unsqueeze(0), output)? In the DSHRED paper, the static attention mechanism computes the importance of each utterance as e_i = V tanh(W h_i + U h_s), where h_i and h_s denote the hidden-state representations of the i-th utterance and of the last utterance in the conversation, respectively. So I think h_s should be output[-1] instead of output[0]. Is that right? Thanks.
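
For comparison, a standalone sketch of how I read the paper's formula, with h_s taken from the last utterance, might look like this (only an illustration of the equation, not the Attention module used in this repo):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class StaticAttention(nn.Module):
        # e_i = v^T tanh(W h_i + U h_s), with h_s = hidden state of the LAST utterance
        def __init__(self, hidden_size):
            super().__init__()
            self.W = nn.Linear(hidden_size, hidden_size, bias=False)
            self.U = nn.Linear(hidden_size, hidden_size, bias=False)
            self.v = nn.Linear(hidden_size, 1, bias=False)

        def forward(self, output):
            # output: [turn_len, batch, hidden_size]
            h_s = output[-1]                                      # last utterance, [batch, hidden]
            e = self.v(torch.tanh(self.W(output) + self.U(h_s)))  # [turn_len, batch, 1]
            alpha = F.softmax(e, dim=0)                           # weights over utterances
            context = (alpha * output).sum(dim=0)                 # [batch, hidden]
            return context.unsqueeze(0)                           # [1, batch, hidden]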

ZihaoW123 · Mar 11, 2021