practical-pytorch

Scheduled sampling in batched seq2seq

Open koustuvsinha opened this issue 7 years ago • 2 comments

Hi, in the batched_seq2seq example you mention using Scheduled Sampling, but there is no implementation of it in the code. I am confused about how to use it in a batched setting. In the non-batched case, when there is no teacher forcing, we did:

topv, topi = decoder_output.data.topk(1)
ni = topi[0][0]
decoder_input = Variable(torch.LongTensor([[ni]]))

But in batched mode, the decoder input at each step should be a full batch of target tokens, as in the teacher-forcing part:

decoder_input = target_variables[di]

How should I supply the previous outputs as the next decoder input in the batched setting when not using teacher forcing?

koustuvsinha avatar Jul 27 '17 22:07 koustuvsinha

It seems the same approach works in batched mode:

topv, topi = decoder_output.data.topk(1, dim=1)
decoder_input = Variable(topi.squeeze(1))

czs0x55aa avatar Aug 30 '17 13:08 czs0x55aa
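
For reference, here is a minimal sketch (not code from the repo) of how scheduled sampling could look in a batched decoding loop, combining the teacher-forcing input from the question with the per-batch top-k feedback above. The names decoder, decoder_hidden, target_variables, SOS_token and teacher_forcing_ratio are assumptions following the tutorial's conventions, and it uses plain tensors rather than Variable:

import random
import torch

def decode_with_scheduled_sampling(decoder, decoder_hidden, target_variables,
                                   SOS_token, teacher_forcing_ratio=0.5):
    # Assumes decoder(input, hidden) -> (output, hidden), where input is a
    # (batch,) tensor of token ids and output is (batch, vocab_size) scores,
    # and target_variables is (max_len, batch). These are assumptions, not
    # code from the tutorial.
    max_len, batch_size = target_variables.size()
    decoder_input = torch.full((batch_size,), SOS_token, dtype=torch.long)
    outputs = []
    for di in range(max_len):
        decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
        outputs.append(decoder_output)
        if random.random() < teacher_forcing_ratio:
            # Teacher forcing: feed the ground-truth tokens for this step.
            decoder_input = target_variables[di]
        else:
            # Feed the model's own top predictions, one per batch element,
            # as in the comment above.
            topv, topi = decoder_output.topk(1, dim=1)   # (batch, 1)
            decoder_input = topi.squeeze(1)              # (batch,)
    return torch.stack(outputs)  # (max_len, batch, vocab_size)

In scheduled sampling proper, the probability of feeding back model predictions is annealed over training rather than kept at a fixed ratio.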

@czs0x55aa How do you handle the if ni == EOS_token: break check in batched mode, since at a given step some decoder outputs are EOS and some are not?

zhongpeixiang avatar Dec 09 '17 09:12 zhongpeixiang
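
One common way to deal with this (a sketch, not code from the repo) is to keep a per-sequence finished mask and stop only once every sequence in the batch has produced EOS; during training, you would typically mask the loss at positions after EOS instead of breaking at all. The decoder interface and names below follow the same assumptions as the sketch above:

import torch

def greedy_decode_batched(decoder, decoder_hidden, batch_size,
                          SOS_token, EOS_token, max_len):
    # Track which sequences in the batch have already emitted EOS.
    decoder_input = torch.full((batch_size,), SOS_token, dtype=torch.long)
    finished = torch.zeros(batch_size, dtype=torch.bool)
    decoded_tokens = []
    for _ in range(max_len):
        decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
        topv, topi = decoder_output.topk(1, dim=1)
        next_tokens = topi.squeeze(1)                 # (batch,)
        finished |= next_tokens.eq(EOS_token)         # mark newly finished rows
        decoded_tokens.append(next_tokens)
        if finished.all():                            # stop only when all hit EOS
            break
        decoder_input = next_tokens
    return torch.stack(decoded_tokens)  # (steps, batch)

Sequences that finish early keep producing tokens until the whole batch is done; those trailing tokens are simply ignored (or masked out of the loss) afterwards.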