OpenNMT-py
Possible dimension transformation bug in RNNEncoder Bridge?
Hey! Thanks for writing this wonderful package. I am currently reading the code in onmt/encoders/rnn_encoder.py, and I do not quite understand the implementation on lines 106 and 107:
```python
size = states.size()
result = linear(states.view(-1, self.total_hidden_dim))
return F.relu(result).view(size)
```
If I am understanding correctly, the `states` variable has dimension `(num_layers x batch_size x encoder_rnn_size)`. The code then flattens it and applies a linear transformation, presumably to each element in the batch. However, `.view(-1, self.total_hidden_dim)` does not seem to flatten per batch element and will probably mix different batch elements into one row.
e.g., suppose num_layers = 2, batch_size = 5, encoder_rnn_size = 3 (a toy example that I used to debug):
```
states =
tensor([[[-0.0028,  0.0173,  0.0113],
         [-0.0028,  0.0173,  0.0113],
         [-0.0028,  0.0173,  0.0113],
         [-0.0028,  0.0173,  0.0113],
         [-0.0028,  0.0173,  0.0113]],

        [[-0.0349, -0.0001, -0.0132],
         [-0.0351, -0.0002, -0.0131],
         [-0.0350, -0.0002, -0.0133],
         [-0.0351, -0.0002, -0.0131],
         [-0.0350, -0.0002, -0.0133]]], grad_fn=<StackBackward>)
```
where [-0.0028, 0.0173, 0.0113] and [-0.0349, -0.0001, -0.0132] belong to the same batch element. But after the `.view(-1, total_hidden_dim)` operation, it becomes:
```
tensor([[-0.0028,  0.0173,  0.0113, -0.0028,  0.0173,  0.0113],
        [-0.0028,  0.0173,  0.0113, -0.0028,  0.0173,  0.0113],
        [-0.0028,  0.0173,  0.0113, -0.0349, -0.0001, -0.0132],
        [-0.0351, -0.0002, -0.0131, -0.0350, -0.0002, -0.0133],
        [-0.0351, -0.0002, -0.0131, -0.0350, -0.0002, -0.0133]], grad_fn=<ViewBackward>)
```
and each row no longer represents a single element in the batch (which I guess is not the intended result?).
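For reference, here is a minimal standalone sketch (with made-up values, not the actual OpenNMT-py module) that reproduces the mixing:

```python
import torch

num_layers, batch_size, hidden_size = 2, 5, 3
total_hidden_dim = num_layers * hidden_size

# Toy state of shape (num_layers, batch_size, hidden_size).
states = torch.arange(num_layers * batch_size * hidden_size, dtype=torch.float32)
states = states.view(num_layers, batch_size, hidden_size)

# Flattening directly, as the current bridge code does, puts two consecutive
# batch elements of the same layer into one row.
flattened = states.view(-1, total_hidden_dim)
print(flattened[0])  # tensor([0., 1., 2., 3., 4., 5.]) -> layer 0, batch elements 0 and 1
```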
Would the correct implementation be:
```python
states = states.permute(1, 0, 2).contiguous()
size = states.size()
result = linear(states.view(-1, self.total_hidden_dim))
result = F.relu(result).view(size)
return result.permute(1, 0, 2).contiguous()
```
(I do not know whether it is an efficient implementation, though.)
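A quick sanity check of the permuted version, using the same toy shapes and a stand-in linear layer (not the actual bridge module):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

num_layers, batch_size, hidden_size = 2, 5, 3
total_hidden_dim = num_layers * hidden_size
linear = nn.Linear(total_hidden_dim, total_hidden_dim)  # stand-in for the bridge's linear layer

states = torch.randn(num_layers, batch_size, hidden_size)

# Put the batch dimension first so each flattened row holds all layer
# states of a single batch element.
permuted = states.permute(1, 0, 2).contiguous()   # (batch, layers, hidden)
flat = permuted.view(-1, total_hidden_dim)

# Row 0 is now the concatenation of batch element 0's states from both layers.
assert torch.equal(flat[0], torch.cat([states[0, 0], states[1, 0]]))

size = permuted.size()
result = F.relu(linear(flat)).view(size)
result = result.permute(1, 0, 2).contiguous()     # back to (layers, batch, hidden)
assert result.shape == states.shape
```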
Thanks and looking forward to hearing from the developers!
same question
Indeed, this is a bug. The state should be permuted as shown above. It also seems the bridge does not work correctly when the RNN is bidirectional, as the state has shape `(num_layers * num_directions, batch, hidden_size)` in this case.
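One possible way to handle that case (just a sketch, not necessarily how the bridge should be fixed) would be to fold the two directions into the feature dimension before bridging:

```python
import torch

num_layers, num_directions, batch_size, hidden_size = 2, 2, 5, 3

# Bidirectional RNN state: (num_layers * num_directions, batch, hidden_size).
states = torch.randn(num_layers * num_directions, batch_size, hidden_size)

# Separate layers from directions, move the batch dimension first, and
# concatenate the two directions along the feature dimension so the bridge
# sees one vector per (batch element, layer).
reshaped = states.view(num_layers, num_directions, batch_size, hidden_size)
reshaped = reshaped.permute(2, 0, 1, 3).contiguous()   # (batch, layers, dirs, hidden)
reshaped = reshaped.view(batch_size, num_layers, num_directions * hidden_size)

print(reshaped.shape)  # torch.Size([5, 2, 6])
```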
Feel free to send a PR.
Fixed in v3.