pytorch-sentiment-neuron

Code to load weights

Open ahirner opened this issue 7 years ago • 7 comments

The code in models.py constructs the graph in a very sleek way. Is it possible to see how you transformed the weights into mlstm_ns.pt too?

ahirner avatar May 06 '17 15:05 ahirner

This is the code I used to load the weights from the numpy files:

import numpy as np
import torch

# Embedding and fully connected output layer
embed.weight.data = torch.from_numpy(np.load("embd.npy"))
rnn.h2o.weight.data = torch.from_numpy(np.load("w.npy")).t()
rnn.h2o.bias.data = torch.from_numpy(np.load("b.npy"))

# mLSTM layer weights; .t() transposes from the TF layout
rnn.layers[0].wx.weight.data = torch.from_numpy(np.load("wx.npy")).t()
rnn.layers[0].wh.weight.data = torch.from_numpy(np.load("wh.npy")).t()
rnn.layers[0].wh.bias.data = torch.from_numpy(np.load("b0.npy"))
rnn.layers[0].wmx.weight.data = torch.from_numpy(np.load("wmx.npy")).t()
rnn.layers[0].wmh.weight.data = torch.from_numpy(np.load("wmh.npy")).t()
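As an aside on why every matrix above goes through .t(): TensorFlow stores dense weights as (in_features, out_features), while PyTorch's nn.Linear stores (out_features, in_features). A minimal numpy illustration of the convention (shapes chosen arbitrarily):

```python
import numpy as np

# TF-style weight: (in_features, out_features)
w_tf = np.arange(6, dtype=np.float32).reshape(3, 2)

# PyTorch nn.Linear expects (out_features, in_features),
# so the array must be transposed before assignment.
w_pt = w_tf.T
assert w_pt.shape == (2, 3)

# y = x @ w_tf (TF convention) equals y = x @ w_pt.T (PyTorch convention)
x = np.ones((1, 3), dtype=np.float32)
assert np.allclose(x @ w_tf, x @ w_pt.T)
```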

guillitte avatar May 06 '17 17:05 guillitte

Thx for reverse engineering and sharing!

ahirner avatar May 06 '17 18:05 ahirner

I added the lm.py file, which allows retraining the model on new data. It was used to create the model and load the weights.

guillitte avatar May 06 '17 19:05 guillitte

I tried to map the original TF variables, via the original .npy files, to your .npy files. Is this mainly correct? Also, I don't know how 14.npy and 15.npy would be used, if at all (b0?), or which file corresponds to gmh in the pytorch version.

#Embedding for ASCII one-hot
embd = 0.npy = embed.npy

#State
wh = 1.npy = wh.npy
wmx = concat(2:6.npy) = wmx.npy
wmh = empty? = wmh.npy

gx = empty
gh = empty
gmx = empty
gmh = 7.npy = ?

wx = 8.npy = wx.npy
wh = 9.npy = wh.npy
wmx = 10.npy = wmx.npy
wmh = 11.npy = wmh.npy

#Fully connected
w = 12.npy = w.npy

ahirner avatar May 07 '17 08:05 ahirner

Things are more complicated than this, because the tf model is using l2 regularization. Pytorch handles this differently. This is why I had to hack the tensorflow model to produce the different npy files.

guillitte avatar May 13 '17 12:05 guillitte

Interesting, I assume you extracted the variables from a live TF graph then. I also found that L2 is (usually?) added in pytorch's optimizer, and suspect that was the difference you're talking about. Thanks!
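For context on the optimizer point: L2 regularization in PyTorch is typically applied via the optimizer's weight_decay argument, which folds the penalty into the gradient at step time, so it leaves no extra entries in a checkpoint. A minimal numpy sketch of what that step does:

```python
import numpy as np

# SGD step with L2 regularization folded into the gradient:
#   grad_total = grad + weight_decay * w
# This mirrors optimizer-level weight_decay in PyTorch; no extra
# parameters appear in the saved weights (unlike weight normalization).
lr, weight_decay = 0.1, 0.01
w = np.ones(3)
grad = np.zeros(3)  # pretend the loss gradient is zero
w = w - lr * (grad + weight_decay * w)
assert np.allclose(w, 0.999)
```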

ahirner avatar May 14 '17 09:05 ahirner

Forgive me, it is not L2 regularization but weight normalization that is the problem. And yes, I extracted the variables with tf code.
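For reference, weight normalization reparameterizes each weight matrix as w = g · v / ||v||, which would explain the separate g* arrays (gx, gh, gmx, gmh) in the TF checkpoint alongside the direction matrices. A minimal numpy sketch of folding the gain back into a plain weight, assuming the norm is taken per output column (the axis choice is an assumption; the original TF code may differ):

```python
import numpy as np

def fold_weight_norm(v, g):
    """Collapse weight-normalized parameters (v, g) into one dense
    weight: w = g * v / ||v||, with the norm per output column.
    (Axis convention is an assumption, not taken from the repo.)"""
    norm = np.linalg.norm(v, axis=0, keepdims=True)
    return g * v / norm

v = np.random.randn(4, 3).astype(np.float32)
g = np.random.randn(3).astype(np.float32)
w = fold_weight_norm(v, g)

# Each column of w then has norm |g| for that column.
assert np.allclose(np.linalg.norm(w, axis=0), np.abs(g), atol=1e-5)
```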

guillitte avatar May 14 '17 11:05 guillitte