
converted model has different weights than the original model

mohamedabdallah1996 opened this issue 3 years ago · 0 comments

I followed the steps mentioned in the repo to convert a PyTorch model to a Keras model. The converted model has the same architecture, but when I call predict on it I get different results. After some debugging I found that the weights of the converted model differ from those of the original.

What I did:

import numpy as np
import torch
from torch.autograd import Variable

from pytorch2keras.converter import pytorch_to_keras

# load the model
model = torch.load('torch_model.pth')

# make dummy inputs
input_np1 = np.random.uniform(0, 1, (1, 3, 224, 224))
input_var1 = Variable(torch.FloatTensor(input_np1)).cuda()

input_np2 = np.random.uniform(0, 1, (1, 3, 224, 224))
input_var2 = Variable(torch.FloatTensor(input_np2)).cuda()

# convert to a Keras model (the model takes two inputs)
k_model = pytorch_to_keras(model, [input_var1, input_var2], [(3, 224, 224), (3, 224, 224)], verbose=True)

# save the Keras model
k_model.save('keras_model.h5', include_optimizer=True)

I compared the weights of the first layer of the original and the converted model. The difference between them was:

array([[[[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
           0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
         [ 3.27792168e-02, -2.67422020e-01, -2.12251216e-01, ...,
          -3.10191065e-01,  7.50064999e-02,  3.58290404e-01],
         [-2.15935856e-01, -5.00378549e-01,  3.05846363e-01, ...,
          -2.57776797e-01,  1.30452141e-01,  7.24201679e-01]],

        [[ 7.25828767e-01,  5.54713011e-01,  7.93337941e-01, ...,
          -9.43963081e-02,  3.59809697e-02, -3.87231857e-01],
         [-3.46688509e-01,  2.44602114e-02,  1.38779831e+00, ...,
          -3.14230382e-01,  1.26919508e-01,  1.92846417e-01],
         [-9.31575656e-01, -6.16762280e-01,  1.10421926e-01, ...,
          -7.26759732e-02,  2.26708114e-01,  6.88698769e-01]],

        [[ 1.21918261e+00,  4.36716199e-01,  8.56265873e-02, ...,
          -1.41018592e-02,  6.56825230e-02, -6.12603426e-01],
         [-3.33942957e-02,  8.63668323e-03,  2.41334200e-01, ...,
          -3.54576558e-02,  1.76371455e-01, -2.04365671e-01],
         [-9.11614120e-01, -1.25450209e-01, -6.84315085e-01, ...,
          -4.09535542e-02,  3.04877669e-01,  3.39422584e-01]]],


       [[[-7.58607984e-01, -2.87290990e-01, -5.81086755e-01, ...,
           4.04587388e-01, -1.10987470e-01,  2.89414525e-02],
         [ 3.42259519e-02, -6.22241199e-01, -1.13678348e+00, ...,
           1.85702033e-02,  7.92933404e-02,  2.37115413e-01],
         [ 5.56599855e-01, -6.21848583e-01,  6.58078939e-02, ...,
          -9.19592381e-02,  1.82882920e-01,  3.37736934e-01]],

        [[ 3.12462538e-01,  5.97781003e-01, -2.51014769e-01, ...,
           2.95660198e-01, -2.06212848e-01, -4.29961830e-01],
         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
           0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
         [ 1.77673489e-01, -7.19010532e-01, -7.99583316e-01, ...,
           1.13688819e-01,  2.36494333e-01,  2.93749720e-01]],

        [[ 6.33270502e-01,  4.88228977e-01,  3.93207967e-01, ...,
           1.79638714e-02, -1.51867911e-01, -4.47701693e-01],
         [-8.75645578e-02,  4.47005033e-02,  6.40237749e-01, ...,
          -6.42660633e-02, -1.96756348e-02, -1.52935743e-01],
         [-3.54735851e-01, -1.93259984e-01, -5.32801509e-01, ...,
          -4.32812832e-02,  2.36783803e-01,  1.05169579e-01]]],


       [[[-1.00324678e+00,  6.36623502e-02, -3.91472936e-01, ...,
           2.71878660e-01, -1.96134657e-01, -1.11598253e-01],
         [ 2.98305154e-01,  1.28533334e-01, -5.03629863e-01, ...,
           5.47121018e-02, -7.48402029e-02, -2.40997061e-01],
         [ 1.05954778e+00,  1.85320526e-01,  4.54589099e-01, ...,
          -2.71224417e-03,  1.58987790e-02, -4.78732646e-01]],

        [[-5.23205578e-01,  6.13211870e-01, -3.07142109e-01, ...,
           1.27416894e-01, -3.59254360e-01, -1.33371264e-01],
         [-9.01089385e-02,  6.74310029e-01,  1.59345567e-01, ...,
          -4.94227596e-02, -2.16818705e-01, -1.40813962e-01],
         [ 5.06306708e-01,  2.40937799e-01, -1.57638654e-01, ...,
           4.02870297e-04, -2.04735249e-02, -2.64339805e-01]],

        [[-1.47933692e-01, -5.98703288e-02,  2.29725987e-01, ...,
           4.36657965e-02, -3.20776463e-01,  1.39310062e-01],
         [-1.51570871e-01, -4.76778150e-02,  6.90440178e-01, ...,
           4.28784154e-02, -2.16310292e-01,  1.59170240e-01],
         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
           0.00000000e+00,  0.00000000e+00,  0.00000000e+00]]]],
      dtype=float32)
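One thing worth checking before concluding the weights are wrong: PyTorch and Keras store convolution weights in different memory layouts, so a raw element-wise subtraction of the two arrays can show large "differences" even when the weights are identical. A minimal NumPy-only sketch of this (the shapes here are hypothetical, chosen to match a typical first conv layer):

```python
import numpy as np

# PyTorch Conv2d weights are laid out as (out_channels, in_channels, kH, kW);
# Keras Conv2D kernels are laid out as (kH, kW, in_channels, out_channels).
pt_weight = np.random.randn(64, 3, 7, 7).astype(np.float32)

# The same weights expressed in the Keras layout:
keras_kernel = pt_weight.transpose(2, 3, 1, 0)  # -> shape (7, 7, 3, 64)

# Mapping the Keras kernel back to the PyTorch layout shows the
# weights are actually identical:
diff = np.abs(pt_weight - keras_kernel.transpose(3, 2, 0, 1)).max()
print(diff)  # 0.0
```

So when comparing layer weights between the two frameworks, transpose one side into the other's layout first; only then is a nonzero difference meaningful.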

How can I make sure that both models have the same weights and generate the same output?
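As a sanity check for the "same output" part, the two models can be fed the same dummy inputs and their outputs compared numerically (with a small tolerance, since float32 results rarely match bit-for-bit across frameworks). A minimal sketch of such a check; the helper name and tolerance are my own choices, and with the real models you would pass `model(...).detach().cpu().numpy()` and `k_model.predict(...)`:

```python
import numpy as np

def outputs_match(pt_out, k_out, atol=1e-5):
    """Return True when two model outputs agree within tolerance."""
    return np.allclose(pt_out, k_out, atol=atol)

# Toy demonstration with fake outputs:
a = np.array([1.0, 2.0, 3.0], dtype=np.float32)
b = a + 1e-7  # tiny float noise, as expected between frameworks
print(outputs_match(a, b))  # True
```

Note also that dropout and batch-norm layers behave differently in training mode, so calling `model.eval()` on the PyTorch side before converting and comparing is usually necessary for the outputs to agree.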

mohamedabdallah1996 · Feb 03 '21 14:02