practical-pytorch
Scheduled sampling in batched seq2seq
Hi, in the batched_seq2seq example you mention using scheduled sampling, but there is no implementation of it in the code. I am confused about how to use it in a batched setting. In the non-batched case, when there is no teacher forcing, we did:
topv, topi = decoder_output.data.topk(1)
ni = topi[0][0]
decoder_input = Variable(torch.LongTensor([[ni]]))
But in batched mode the decoder input at each step is a whole batch of target tokens, as in the teacher forcing branch:
decoder_input = target_variables[di]
How do I supply the previous outputs as the next decoder input in batched mode when teacher forcing is not used?
It seems the same approach works in batched mode:
topv, topi = decoder_output.data.topk(1, dim=1)
decoder_input = Variable(topi.squeeze(1))
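
To put that in context, here is a minimal sketch of the per-step scheduled-sampling choice inside a batched training loop. It assumes the variable names from the batched tutorial (decoder, decoder_hidden, encoder_outputs, target_variables, all_decoder_outputs, teacher_forcing_ratio, SOS_token, batch_size, max_target_length) are already in scope; it is a sketch of the idea, not the tutorial's actual code:

import random
import torch
from torch.autograd import Variable

# Start the whole batch with SOS tokens
decoder_input = Variable(torch.LongTensor([SOS_token] * batch_size))

for di in range(max_target_length):
    decoder_output, decoder_hidden = decoder(
        decoder_input, decoder_hidden, encoder_outputs)
    all_decoder_outputs[di] = decoder_output  # loss is computed later over all outputs

    # Scheduled sampling: flip a coin each step to decide whether the next
    # input comes from the targets or from the model's own predictions
    if random.random() < teacher_forcing_ratio:
        # Teacher forcing: feed the current target token of every sequence
        decoder_input = target_variables[di]            # (batch_size,)
    else:
        # Feed the model's own top-1 prediction for every sequence
        topv, topi = decoder_output.data.topk(1, dim=1)
        decoder_input = Variable(topi.squeeze(1))        # (batch_size,)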
@czs0x55aa How do I implement this in batched mode, since some decoder outputs are EOS while others are not? In the non-batched case we had:
if ni == EOS_token: break
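
One way to handle that (again a sketch and an assumption, not code from the tutorial) is to keep a per-sequence finished flag at evaluation time and stop only when every sequence in the batch has produced EOS; during training the usual approach is to decode for the full target length and mask the padded positions in the loss instead of breaking:

from torch.autograd import Variable

# finished[b] is True once sequence b in the batch has produced EOS
finished = [False] * batch_size

for di in range(max_length):
    decoder_output, decoder_hidden = decoder(
        decoder_input, decoder_hidden, encoder_outputs)
    topv, topi = decoder_output.data.topk(1, dim=1)
    next_token = topi.squeeze(1)                 # (batch_size,) LongTensor

    # Mark the sequences that just emitted EOS; break only when all of them have
    for b in range(batch_size):
        if next_token[b] == EOS_token:
            finished[b] = True
    if all(finished):
        break

    decoder_input = Variable(next_token)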