pyvarinf
Error with LSTMs on GPU
Hi @ctallec,
Great package! Thanks for writing it.
I recently found a bug and wanted to ask your advice on it. When I Variationalize an LSTM model, I get an error when moving it to CUDA. For example, here's the working version without pyvarinf:
import torch.nn as nn
net = nn.LSTM(10, 10)
net.cuda()
This works fine. However, after wrapping the model with Variationalize:
import pyvarinf
net = nn.LSTM(10, 10)
net = pyvarinf.Variationalize(net)
net.cuda()
This raises the following error:
---------------------------------------------------------------------------
StopIteration Traceback (most recent call last)
<ipython-input-14-0c80e37b52e0> in <module>()
1 net = nn.LSTM(10, 10)
2 net = pyvarinf.Variationalize(net)
----> 3 net.cuda()
3 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py in flatten_parameters(self)
94 Otherwise, it's a no-op.
95 """
---> 96 any_param = next(self.parameters()).data
97 if not any_param.is_cuda or not torch.backends.cudnn.is_acceptable(any_param):
98 return
StopIteration:
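The last frame shows where the error comes from: flatten_parameters calls next(self.parameters()), and Python's next() raises StopIteration when the iterator is empty, i.e. when the module has no registered parameters. A minimal sketch of that mechanism, assuming (not confirmed in this thread) that after pyvarinf.Variationalize the wrapped LSTM no longer registers any nn.Parameter of its own. The helper first_param_or_none is hypothetical and only illustrates the safer next(iterator, default) pattern:

```python
import torch.nn as nn

# A module with no registered nn.Parameter objects, standing in for an
# LSTM whose weights are no longer registered (an assumption, not confirmed).
empty = nn.Module()

def first_param_or_none(module: nn.Module):
    """Hypothetical helper: return the first registered parameter, or None."""
    # Passing a default to next() avoids the StopIteration from the traceback.
    return next(module.parameters(), None)

try:
    next(empty.parameters())  # same call pattern as rnn.py line 96
    raised = False
except StopIteration:
    raised = True

print(raised)                                            # True
print(first_param_or_none(empty))                        # None
print(first_param_or_none(nn.Linear(2, 2)) is not None)  # True
```

If that assumption holds, any code path that probes a module with next(self.parameters()) would fail the same way on a Variationalized model, independent of CUDA itself.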
Any idea what causes this? Maybe I shouldn't be using this with an RNN-like model?
Thanks! Miles
Hi @MilesCranmer,
This seems to be a problem with the specific implementation of LSTMs. I will look into it asap and keep you updated. Thanks for reporting the issue :+1: