onnx-tensorflow

Train onnx model from another framework

Open letruongthanh3698 opened this issue 4 years ago • 2 comments

Hi everyone,

I am testing example/train_onnx_model.py on Google Colab with an ONNX model exported from MATLAB Deep Learning Toolbox, and it fails with this error:

```
==> Train the model..

ValueError                                Traceback (most recent call last)
in ()
      1 if __name__ == "__main__":
----> 2     train_onnx_model()
      3     run_onnx_model(trained_onnx_model)

2 frames
in train_onnx_model()
     42       feed_dict[training_flag_placeholder] = True
     43       loss, accuracy, _ = sess.run([loss_op, eval_op, opt_op],
---> 44                                    feed_dict=feed_dict)
     45       if (step % 100) == 0:
     46         print('Epoch {}, train step {}, loss:{}, accuracy:{}'.format(

/usr/local/lib/python3.7/dist-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    966     try:
    967       result = self._run(None, fetches, feed_dict, options_ptr,
--> 968                          run_metadata_ptr)
    969       if run_metadata:
    970         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/usr/local/lib/python3.7/dist-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1165                 'Cannot feed value of shape %r for Tensor %r, '
   1166                 'which has shape %r' %
-> 1167                 (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
   1168           if not self.graph.is_feedable(subfeed_t):
   1169             raise ValueError('Tensor %s may not be fed.' % subfeed_t)

ValueError: Cannot feed value of shape (32, 28, 28, 1) for Tensor 'imageinput_Mean/Read/ReadVariableOp:0', which has shape '(1, 1, 1, 1)'
```

Does the model have to be trained in TensorFlow first and then retrained in TensorFlow, or can I retrain a model that was trained in a different framework?

Can anyone help me?

Thank you.

letruongthanh3698 • Jul 01 '21 10:07

I don't think the model has to be initially trained in TensorFlow. @chudegao maybe you can take a look. Thanks.

chinhuang007 • Jul 07 '21 00:07

The ONNX model can be exported from other frameworks. I have tried ONNX models exported from both PyTorch and TensorFlow. Just make sure the model's input and the feed_dict are consistent. From the error message, I guess the ONNX model's input shape is [1,1,1,1] and you are trying to feed data with shape [32,28,28,1].
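One way to check this before training is to list the model's declared input shapes and compare them with the batch you plan to feed. The snippet below is only a sketch: `shape_compatible` and `graph_input_shapes` are hypothetical helper names, not part of onnx-tensorflow, and skipping initializers rests on the assumption that a tensor named like `imageinput_Mean` is a stored constant rather than the data input.

```python
from typing import Dict, Optional, Sequence, Tuple

Shape = Tuple[Optional[int], ...]


def shape_compatible(declared: Sequence[Optional[int]],
                     value_shape: Sequence[int]) -> bool:
    """Return True if value_shape can be fed to a tensor declared with
    `declared`, where None marks a dynamic dimension (e.g. batch size)."""
    if len(declared) != len(value_shape):
        return False
    return all(d is None or d == v for d, v in zip(declared, value_shape))


def graph_input_shapes(model_path: str) -> Dict[str, Shape]:
    """List the data inputs of an ONNX model with their declared shapes.

    Entries of graph.input that also appear in graph.initializer hold
    stored weights or constants, so they are skipped here (assumption:
    those are not the tensors you want to feed training data into).
    """
    import onnx  # imported lazily so shape_compatible works without onnx

    model = onnx.load(model_path)
    initializer_names = {init.name for init in model.graph.initializer}
    shapes: Dict[str, Shape] = {}
    for inp in model.graph.input:
        if inp.name in initializer_names:
            continue
        dims = inp.type.tensor_type.shape.dim
        # A dimension without a fixed dim_value is treated as dynamic.
        shapes[inp.name] = tuple(
            d.dim_value if d.HasField("dim_value") else None for d in dims)
    return shapes


# The shapes from the error message: a (32, 28, 28, 1) batch cannot be fed
# to a tensor declared as (1, 1, 1, 1).
print(shape_compatible((1, 1, 1, 1), (32, 28, 28, 1)))       # False
# A hypothetical input declared with a dynamic batch dimension accepts it.
print(shape_compatible((None, 28, 28, 1), (32, 28, 28, 1)))  # True
```

Calling `graph_input_shapes` on the exported .onnx file and comparing each entry with the shape of the training batch should show whether the feed_dict key points at the image input or at some other tensor.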

chudegao • Jul 07 '21 01:07