recurrent-batch-normalization-pytorch
Not a number
Hey,
Thank you for providing a great example of how to implement custom LSTMs. I have a NaN issue, however. I am trying to use your LSTM as a drop-in replacement for the PyTorch LSTM. In the first iterations all the hidden states are zero vectors, and the values become NaN very soon after. Do you have any idea what might be causing this?
Thanks!
Using the default PyTorch reset_parameters() for initialization actually fixed it!
Then, can this issue be closed? If modification is needed, please let me know more about the details! 😃
Implementing a more standard initialization method as the default in the LSTMCell might be helpful for future users. I just have:
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
    weight.data.uniform_(-stdv, stdv)
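For future users, here is a minimal sketch of how that scheme could live in a custom cell, mirroring torch.nn.LSTM's default reset_parameters (uniform in [-1/sqrt(hidden_size), 1/sqrt(hidden_size)]). The class and attribute names below are only illustrative, not the ones from this repo:

import math
import torch
import torch.nn as nn

class PlainLSTMCell(nn.Module):
    # Illustrative cell, not the repo's BNLSTMCell; only the init logic matters here.
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Usual 4 * hidden_size gate stacking (input, forget, cell, output).
        self.weight_ih = nn.Parameter(torch.empty(4 * hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.empty(4 * hidden_size, hidden_size))
        self.bias = nn.Parameter(torch.empty(4 * hidden_size))
        self.reset_parameters()

    def reset_parameters(self):
        # Same scheme as torch.nn.LSTM's default initialization.
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            nn.init.uniform_(weight, -stdv, stdv)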
I tried to use the initialization scheme stated in the paper; however, it seems there is a problem with it. I will look into the problem in the coming days. Thanks!
I'll mention the issue on the Pytorch forum.
Is there any progress with that issue @kadarakos , @jihunchoi ?
I'm also facing this problem: the loss becomes NaN after about 100 iterations on my dataset, but changing reset_parameters to the default PyTorch method seems to help.
I tried to get to the source of the problem and traced it to this line in bnlstm, "bn_wh = self.bn_hh(wh, time=time)", in the "forward" definition of "BNLSTMCell". At some batch it starts returning NaN for some of the columns, and because BNLSTMCell is called in a loop, the next iteration then has NaN in all of the columns, and from that point on everything is NaN. I stopped troubleshooting and tried the method @kadarakos suggested, and it helped.
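In case it helps anyone else trace this, here is a rough sketch of that kind of step-by-step NaN check. The call signature cell(input_t, (h, c), time=t) is an assumption based on the loop described above, so adjust it to the actual BNLSTMCell interface:

import torch

def find_first_nan(cell, inputs, h0, c0):
    # inputs: (seq_len, batch, input_size); h0, c0: (batch, hidden_size).
    h, c = h0, c0
    for t in range(inputs.size(0)):
        # Assumed call signature; the important part is checking h right after each step.
        h, c = cell(inputs[t], (h, c), time=t)
        if torch.isnan(h).any():
            bad_cols = torch.isnan(h).any(dim=0).nonzero().squeeze(-1)
            print("NaN first appears at time step", t, "in hidden columns", bad_cols.tolist())
            return t
    return None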
Edit: I'm also getting a NaN error even after changing reset_parameters. Sometimes during training the loss starts increasing rapidly (not far above the loss value at the start of training, but the increase is noticeable, and the accuracy/error also gets worse), and after a couple of iterations like this I get a NaN loss.
Thanks.