dqn-pytorch
Missing batch_first attribute for LSTM model.
By default, `BATCH_SIZE = 32`. The input to the LSTM from the CNN has shape `(32, 64, 16)`. By default the LSTM expects input of shape `(seq_len, batch_size, input_size)`, but the input here is in the format `(batch_size, seq_len, input_size)`. To correct this, `batch_first=True` needs to be passed while creating the LSTM model:

self.lstm = nn.LSTM(16, LSTM_MEMORY, 1, batch_first=True)
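For reference, a minimal sketch of the corrected construction (assuming `LSTM_MEMORY = 128` for illustration; the repo defines its own value):

```python
import torch
import torch.nn as nn

LSTM_MEMORY = 128  # assumed value for illustration; the repo defines its own

# With batch_first=True the LSTM consumes (batch_size, seq_len, input_size),
# which matches the (32, 64, 16) tensor coming out of the CNN.
lstm = nn.LSTM(16, LSTM_MEMORY, 1, batch_first=True)

x = torch.randn(32, 64, 16)  # (batch_size, seq_len, input_size)
out, (h_n, c_n) = lstm(x)
print(out.shape)  # torch.Size([32, 64, 128])
print(h_n.shape)  # torch.Size([1, 32, 128]) -- states stay (n_layers, batch_size, hidden_size)
```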
The astonishing part is that the model is still learning despite this error.
Wow! As you said, it is astonishing haha :) I will fix this error soon!
Cool. I am fixing it as well. Will try for a PR in a couple of days.
Good :) I will wait for your PR. I think your contribution is more valuable than my own fix. After reviewing your PR, I will close this issue. :) Good job!
When checking the batch size of the input to the LSTM's forward method using `print(x.shape)`, the following is obtained:
torch.Size([32, 4, 84, 84])
torch.Size([32, 4, 84, 84])
torch.Size([32, 4, 84, 84])
torch.Size([31, 4, 84, 84])
torch.Size([32, 4, 84, 84])
torch.Size([31, 4, 84, 84])
torch.Size([32, 4, 84, 84])
torch.Size([30, 4, 84, 84])
torch.Size([32, 4, 84, 84])
This shows that the batch size changes across inputs. This would break the code in the forward method, since `hidden_state` and `cell_state` are initialized by the `init_states` method using `BATCH_SIZE`, which would become 32 (currently it uses 64 as the batch size). Is there any way to make the batch size of the input consistent?
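If not, one possible workaround is to size the states from the input tensor itself inside `forward`. A minimal sketch of the idea, with assumed values (`LSTM_MEMORY = 128`, one layer); the repo's actual model will differ in detail:

```python
import torch
import torch.nn as nn

LSTM_MEMORY = 128  # assumed value for illustration

class ModelSketch(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(16, LSTM_MEMORY, 1, batch_first=True)

    def init_states(self, batch_size):
        # (n_layers, batch_size, hidden_size)
        return (torch.zeros(1, batch_size, LSTM_MEMORY),
                torch.zeros(1, batch_size, LSTM_MEMORY))

    def forward(self, x):
        # Size the recurrent states from the actual input batch (30, 31,
        # or 32 in the printout above) instead of the fixed BATCH_SIZE.
        hidden_state, cell_state = self.init_states(x.size(0))
        out, _ = self.lstm(x, (hidden_state, cell_state))
        return out

model = ModelSketch()
print(model(torch.randn(31, 64, 16)).shape)  # torch.Size([31, 64, 128])
```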
Also, could you let me know the sources of inspiration for this code? That might help in fixing the issue more quickly. Thanks.
Another issue that would need to be looked at is `batch_size` when using the `init_states` method to initialize `hidden_state` and `cell_state`.
Hidden/cell state semantics: `(n_layers, batch_size, hidden_size)`
Since `batch_size` would be 1 while training (one sample is added at a time to the replay memory), `train_hidden_state` and `train_cell_state` would use `batch_size=1` for the dimension semantics, while `dqn_hidden_state` and `test_hidden_state` would use `batch_size=32`.
The `init_states` method would be modified to accept `batch_size` as an argument and return `hidden_state` and `cell_state` of the appropriate shape.
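A standalone sketch of that proposed change (assuming `LSTM_MEMORY = 128` and a single LSTM layer; in the repo this would remain a method on the model):

```python
import torch

LSTM_MEMORY = 128  # assumed value for illustration
N_LAYERS = 1       # assumed value for illustration

def init_states(batch_size):
    # Hidden/cell state semantics: (n_layers, batch_size, hidden_size)
    hidden_state = torch.zeros(N_LAYERS, batch_size, LSTM_MEMORY)
    cell_state = torch.zeros(N_LAYERS, batch_size, LSTM_MEMORY)
    return hidden_state, cell_state

# Training: one sample at a time goes into the replay memory.
train_hidden_state, train_cell_state = init_states(batch_size=1)

# DQN/test states use a full minibatch.
dqn_hidden_state, dqn_cell_state = init_states(batch_size=32)

print(train_hidden_state.shape)  # torch.Size([1, 1, 128])
print(dqn_hidden_state.shape)    # torch.Size([1, 32, 128])
```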