Sentence-VAE
Confusion about the batch dimension in model.py (I think batch should be the second dimension of the hidden state.)
I read the code very carefully and am confused about line 65 of model.py. I think the second dimension of "hidden" is the batch dimension, not the first one. Even though the encoder has been set up with "batch_first=True", only the output tensor has batch in the first dimension; the hidden state does not. I have tested this on my own computer.
Of course, the code runs without problems. I am just confused because the two dimensions are mixed up in the code. Can anyone help me with this?
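To illustrate the point being raised: a minimal check (not taken from the repository) showing that in PyTorch, `batch_first=True` only changes the layout of the output tensor, not the returned hidden state:

```python
import torch
import torch.nn as nn

# A small bidirectional GRU with batch_first=True.
rnn = nn.GRU(input_size=4, hidden_size=3, num_layers=1,
             bidirectional=True, batch_first=True)
x = torch.randn(5, 7, 4)  # (batch, seq_len, input_size)
output, hidden = rnn(x)

# output is batch-first: (batch, seq_len, num_directions * hidden_size)
print(output.shape)  # torch.Size([5, 7, 6])
# hidden is NOT batch-first: (num_layers * num_directions, batch, hidden_size)
print(hidden.shape)  # torch.Size([2, 5, 3])
```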
I found the same issue. A simple example:
> t = torch.Tensor([[[1,2,3,4],[5,6,7,8],[9,0,1,2]],[[2,2,3,4],[5,6,7,8],[9,0,1,2]]])
> t
tensor([[[1., 2., 3., 4.],
[5., 6., 7., 8.],
[9., 0., 1., 2.]],
[[2., 2., 3., 4.],
[5., 6., 7., 8.],
[9., 0., 1., 2.]]])
> t.shape
torch.Size([2, 3, 4])
# num_layers * num_directions (2 if bidirectional), batch_size, hidden_size
> t.view(3,4*2)
tensor([[1., 2., 3., 4., 5., 6., 7., 8.],
[9., 0., 1., 2., 2., 2., 3., 4.],
[5., 6., 7., 8., 9., 0., 1., 2.]])
> torch.cat([t[i] for i in range(2)],dim=1)
tensor([[1., 2., 3., 4., 2., 2., 3., 4.],
[5., 6., 7., 8., 5., 6., 7., 8.],
[9., 0., 1., 2., 9., 0., 1., 2.]])
I agree with you; I found that, too.
I fixed this issue by using the function 'permute', as in the following:
import torch
factor = 2  # num_layers * num_directions
bs = 4      # batch size
hs = 3      # hidden size
eh = torch.randn(factor, bs, hs)  # same shape as the encoder hidden state
# Correct: move the batch dimension to the front before flattening.
eh_p = eh.permute(1, 0, 2)
eh_t = eh_p.reshape(bs, factor * hs)
# Wrong: reshaping directly mixes values from different batch elements.
eh_t2 = eh.reshape(bs, factor * hs)
Here I replaced 'view' with 'reshape', as suggested by the error message, since the permuted tensor is no longer contiguous and 'view' cannot be applied to it directly.
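The fix above can be wrapped in a small helper. This is a sketch with illustrative names (`flatten_hidden` is not a function from the repository), not the actual patch to model.py:

```python
import torch

def flatten_hidden(hidden: torch.Tensor) -> torch.Tensor:
    """Flatten an RNN hidden state of shape
    (num_layers * num_directions, batch, hidden_size) into
    (batch, num_layers * num_directions * hidden_size)
    without mixing rows from different batch elements."""
    factor, batch, hidden_size = hidden.shape
    # permute makes the tensor non-contiguous, so use reshape
    # (or .contiguous().view()) rather than view.
    return hidden.permute(1, 0, 2).reshape(batch, factor * hidden_size)

h = torch.randn(2, 4, 3)
print(flatten_hidden(h).shape)  # torch.Size([4, 6])
```

Each output row is then the concatenation of that batch element's slices from every layer/direction, which is what the decoder expects.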