pytorch-seq2seq
Question about tutorial 1 and 2 Decoder
Hello, I had a question about the following line in the decoder in tutorial 2:
prediction = self.fc_out(output)
Why is output computed as
output = torch.cat((embedded.squeeze(0), hidden.squeeze(0), context.squeeze(0)), dim = 1)
as opposed to
output = torch.cat((embedded.squeeze(0), output.squeeze(0), context.squeeze(0)), dim = 1)?
In the tutorial 2 text, it says
Thank you!
When the sequence length is one, which it is when decoding one token at a time, output == hidden. output holds the (top-layer) hidden state from every time step, while hidden holds the hidden state from the final time step only, so with a single time step the two are identical. This holds for a single-layer, unidirectional RNN such as the GRU in the tutorial 2 decoder. Both code snippets in your question therefore do exactly the same thing.
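A minimal sketch to check this numerically, assuming a decoder that loosely mirrors the tutorial 2 setup: a single-layer GRU fed the concatenation of an embedding and a context vector for one time step. The sizes (emb_dim, hid_dim, batch_size) are hypothetical, and the tensors are random stand-ins rather than real embeddings. The last part shows the caveat for a stacked (two-layer) LSTM, where only the top layer of hidden matches output.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (hypothetical, not taken from the tutorial).
emb_dim, hid_dim, batch_size = 8, 16, 4

# Single-layer, unidirectional GRU whose input is [embedding; context].
rnn = nn.GRU(emb_dim + hid_dim, hid_dim)

# One decoding step, so every tensor has seq_len = 1 in dimension 0.
embedded = torch.randn(1, batch_size, emb_dim)
context = torch.randn(1, batch_size, hid_dim)
prev_hidden = torch.randn(1, batch_size, hid_dim)

emb_con = torch.cat((embedded, context), dim=2)
output, hidden = rnn(emb_con, prev_hidden)

# output: [seq_len = 1, batch, hid_dim]; hidden: [n_layers = 1, batch, hid_dim]
print(torch.allclose(output, hidden))

# The two concatenations from the question therefore give the same tensor.
with_hidden = torch.cat((embedded.squeeze(0), hidden.squeeze(0), context.squeeze(0)), dim=1)
with_output = torch.cat((embedded.squeeze(0), output.squeeze(0), context.squeeze(0)), dim=1)
print(torch.allclose(with_hidden, with_output))

# Caveat: with a stacked RNN, hidden has one entry per layer, and only the
# top layer (h2[-1]) matches the single-step output.
lstm = nn.LSTM(emb_dim, hid_dim, num_layers=2)
out2, (h2, c2) = lstm(embedded)
print(out2.shape, h2.shape)
print(torch.allclose(out2[0], h2[-1]))
```

With more than one layer (or a bidirectional RNN), swapping hidden for output is no longer a drop-in change, because hidden then contains one state per layer and direction.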