Does `batch_first` really also affect the shape of the RNN's hidden state?

Open iamkissg opened this issue 7 years ago • 1 comments

Hi, in your code for 12_1_rnn_basics (https://github.com/hunkim/PyTorchZeroToAll/blob/master/12_1_rnn_basics.py#L15), you wrote this comment:

`# (batch, num_layers * num_directions, hidden_size) for batch_first=True`

However, PyTorch's docs (http://pytorch.org/docs/master/nn.html?highlight=batch_first#rnn) only say:

batch_first – If True, then the input and output tensors are provided as (batch, seq, feature)

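I was confused about which one is right, so I printed the shape of the RNN's hidden state after executing this line (https://github.com/hunkim/PyTorchZeroToAll/blob/59c86cc42b305807789291564501a667689fb812/12_1_rnn_basics.py#L45). It turned out to be [1, 3, 2].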

So could you please tell me which one is right, and explain why the output is [1, 3, 2]?

Thank you very much.

iamkissg, Dec 28 '17 17:12

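It's a bug in the comment.

This is the correct shape, even for batch_first=True:

`# (num_layers * num_directions, batch, hidden_size)`

`batch_first` only changes the layout of the input and output tensors, not the hidden state. So [1, 3, 2] reads as num_layers * num_directions = 1, batch = 3, hidden_size = 2.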

Can you send me a PR?
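For anyone who wants to check this locally, here is a minimal sketch. The sizes (batch of 3, sequence length 5, input size 4, hidden size 2) are assumptions picked to reproduce the [1, 3, 2] shape from the question; they are not copied from the repository script. It builds one `nn.RNN` with `batch_first=True` and one without, and prints the shapes of the output and the hidden state for each.

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumed), chosen so the hidden state comes out as [1, 3, 2]
batch, seq_len, input_size, hidden_size = 3, 5, 4, 2

# Input laid out as (batch, seq, feature), which is what batch_first=True expects
x = torch.randn(batch, seq_len, input_size)

# Single-layer, unidirectional RNN with batch_first=True
rnn_bf = nn.RNN(input_size=input_size, hidden_size=hidden_size, batch_first=True)
out_bf, hidden_bf = rnn_bf(x)
print(out_bf.shape)     # torch.Size([3, 5, 2]) -> (batch, seq, hidden_size)
print(hidden_bf.shape)  # torch.Size([1, 3, 2]) -> (num_layers * num_directions, batch, hidden_size)

# Same kind of RNN with the default batch_first=False; input must be (seq, batch, feature)
rnn_tm = nn.RNN(input_size=input_size, hidden_size=hidden_size)
out_tm, hidden_tm = rnn_tm(x.transpose(0, 1))
print(out_tm.shape)     # torch.Size([5, 3, 2]) -> (seq, batch, hidden_size)
print(hidden_tm.shape)  # torch.Size([1, 3, 2]) -> same layout as with batch_first=True

# For one layer and one direction, the last time step of the output
# is the final hidden state
print(torch.allclose(out_bf[:, -1, :], hidden_bf[0]))
```

Only the output tensor changes layout between the two modules; the hidden state keeps the batch in its second dimension either way.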


hunkim, Dec 28 '17 22:12