pytorch-nlp-notebooks
pytorch-nlp-notebooks copied to clipboard
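In `forward`, `self.decoder(output[:, :, -1]).squeeze()` should be `self.decoder(output[:, -1, :]).squeeze()`: assuming the RNN is batch-first, its output is shaped (batch, seq_len, hidden_size), so `[:, :, -1]` picks only the last hidden feature at every time step, whereas `[:, -1, :]` picks the full hidden state at the last time step, which is what the decoder should receive.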
def forward(self, inputs):
    """Encode a batch, run it through the RNN and decode the last time step.

    Args:
        inputs: Batch to process. Its first dimension is read as the batch
            size and the tensor is passed unchanged to ``self.encoder``.

    Returns:
        The output of ``self.decoder`` applied to the RNN output at the last
        time step of every sequence, with all size-1 dimensions squeezed out.
    """
    # Avoid breaking if the last batch has a different size.
    # (init_hidden() takes no arguments, so it presumably sizes the initial
    # hidden state from self.batch_size -- confirm against its definition.)
    batch_size = inputs.size(0)
    if batch_size != self.batch_size:
        self.batch_size = batch_size

    encoded = self.encoder(inputs)
    # The final hidden state returned by the RNN is not used.
    output, _ = self.rnn(encoded, self.init_hidden())

    # Select the full hidden vector at the LAST time step. The previous
    # slice, output[:, :, -1], picked the last hidden feature at every time
    # step instead, which only matched the decoder's input width when the
    # sequence length happened to equal the hidden size.
    # Assumes self.rnn is batch-first, i.e. output is
    # (batch, seq_len, hidden_size) -- TODO confirm in __init__.
    # NOTE(review): the bare squeeze() also removes the batch dimension when
    # the batch holds a single element; kept as-is to preserve the return
    # shape callers currently receive.
    output = self.decoder(output[:, -1, :]).squeeze()
    return output