practical-pytorch
Best way to save the model in "RNN Classification"
Hi all, I'm new to PyTorch. PyTorch seems to offer a simple way to save a model, but I ran into a problem doing so while working through the RNN Classification tutorial. To save the model I simply call `torch.save(rnn, 'char-rnn-classification.pt')`, and then, as in `predict.py`, I load it with `rnn = torch.load('char-rnn-classification.pt')`. This should save the entire model, from the network definition to the weights. The model file is saved successfully, but when I run prediction on a test input I get the error below. Does anybody know how to save the model correctly?
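A common, more robust alternative to pickling the whole model object is to save its `state_dict` together with the character list used to encode the input. Pickling the whole object stores a reference to the model class, so `model.py` must be importable when loading. The sketch below assumes a recent PyTorch. `RNN` is a simplified stand-in with the same `i2h`/`i2o` layers seen in the traceback, and `ALL_LETTERS`, `save_checkpoint` and `load_checkpoint` are hypothetical names, not part of the tutorial.

```python
import os
import string
import tempfile

import torch
import torch.nn as nn

# Hypothetical vocabulary; the tutorial defines its own character list.
ALL_LETTERS = string.ascii_letters + " .,;'-"


class RNN(nn.Module):
    """Stand-in for the tutorial's model: i2h/i2o layers over [input, hidden]."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), 1)
        hidden = self.i2h(combined)
        output = self.softmax(self.i2o(combined))
        return output, hidden


def save_checkpoint(model, letters, path):
    # Save the weights plus everything needed to rebuild the same encoding.
    torch.save({
        "state_dict": model.state_dict(),
        "letters": letters,
        "hidden_size": model.hidden_size,
        "output_size": model.i2o.out_features,
    }, path)


def load_checkpoint(path):
    ckpt = torch.load(path, map_location="cpu")
    letters = ckpt["letters"]
    # Size the model from the saved vocabulary, not from whatever
    # character list happens to be defined at prediction time.
    model = RNN(len(letters), ckpt["hidden_size"], ckpt["output_size"])
    model.load_state_dict(ckpt["state_dict"])
    model.eval()
    return model, letters


# Usage: save after training, then reload in predict.py.
rnn = RNN(len(ALL_LETTERS), 128, 18)
path = os.path.join(tempfile.mkdtemp(), "char-rnn-classification.pt")
save_checkpoint(rnn, ALL_LETTERS, path)
restored, letters = load_checkpoint(path)
```

At prediction time, the returned `letters` would then be passed to the line-encoding function instead of a module-level constant, so the input width always matches the saved weights.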
```
python predict.py Satoshi
Traceback (most recent call last):
  File "predict.py", line 32, in <module>
    predict(sys.argv[1])
  File "predict.py", line 17, in predict
    output = evaluate(Variable(lineToTensor(line)))
  File "predict.py", line 12, in evaluate
    output, hidden = rnn(line_tensor[i], hidden)
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 206, in __call__
    result = self.forward(*input, **kwargs)
  File "/media/mspl/ext1/Desktop/Andi/pytorch/practical-pytorch-master/char-rnn-classification/model.py", line 17, in forward
    hidden = self.i2h(combined)
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 206, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/linear.py", line 54, in forward
    return self._backend.Linear()(input, self.weight, self.bias)
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/_functions/linear.py", line 10, in forward
    output.addmm_(0, 1, input, weight.t())
RuntimeError: size mismatch, m1: [1 x 186], m2: [185 x 128] at /b/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:1237
```
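Make sure your dataset is exactly the same, because `lineToTensor` relies on the number and order of characters in `all_characters` to create the input tensors. The shapes in the error point to this: the `i2h` weights were trained on 185 input features but received 186. Since `i2h` produces the 128-wide hidden state, the character list at prediction time has one more entry than it had at training time. Another solution is to attach that function and the character list directly to the model, so that prediction always encodes names with the same vocabulary that was used for training.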
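To make the dependency concrete, here is a pure-Python sketch (no PyTorch needed) of a `lineToTensor`-style one-hot encoding. `line_to_one_hot`, `i2h_input_width` and the two character lists are hypothetical. The actual characters that differed in the original setup are unknown. The sketch only shows how a single extra character turns the 185 inputs the saved weights expect into 186, and how reordering the list changes the encoding even when the size stays the same.

```python
import string

HIDDEN_SIZE = 128  # the i2h output width implied by the [185 x 128] weight


def line_to_one_hot(line, letters):
    """Pure-Python analogue of lineToTensor: one row per character,
    each row a one-hot vector of length len(letters)."""
    rows = []
    for ch in line:
        row = [0] * len(letters)
        row[letters.index(ch)] = 1  # position depends on the list's order
        rows.append(row)
    return rows


def i2h_input_width(letters, hidden_size=HIDDEN_SIZE):
    # i2h sees [one-hot character, hidden state] concatenated.
    return len(letters) + hidden_size


# Hypothetical vocabularies: training used 57 characters, prediction 58.
train_letters = string.ascii_letters + " .,;'"
predict_letters = string.ascii_letters + " .,;'-"

print(i2h_input_width(train_letters))    # 185: width the saved weights expect
print(i2h_input_width(predict_letters))  # 186: width fed in at prediction time

# Same size, different order: "S" lands in a different one-hot slot.
reordered = "".join(sorted(train_letters))
print(line_to_one_hot("S", train_letters)[0].index(1))
print(line_to_one_hot("S", reordered)[0].index(1))
```

A size change fails loudly, as in the traceback above. A reordering produces no error at all, only wrong predictions, which is why saving the character list alongside the weights is safer than relying on it being rebuilt identically.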