Cannot export rnn together with vgg.
Describe the bug
I need to export a PyTorch model that consists of a VGG front end and an RNN into TensorFlow. For this purpose I use the PyTorch -> ONNX -> TensorFlow approach.
However, I get the following error message at `model = prepare(onnx_model)`:
`ValueError: Input size (depth of inputs) must be accessible via shape inference, but saw value None.`
Exporting the VGG and the RNN separately works fine; the failure only occurs when they are combined.
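To narrow down where shape inference loses the input depth, here is a small diagnostic sketch (not part of the original report) that runs ONNX shape inference and prints every inferred intermediate shape; the file name `encoder.onnx` matches the repro script below:

```python
# Diagnostic sketch: run ONNX shape inference and print the inferred
# shape of every intermediate tensor, to spot which value feeding the
# RNN ends up with an unknown depth.
import onnx
from onnx import shape_inference

model = onnx.load('encoder.onnx')
inferred = shape_inference.infer_shapes(model)
for vi in inferred.graph.value_info:
    dims = []
    for d in vi.type.tensor_type.shape.dim:
        # dim_param holds a symbolic name (e.g. 'sequence'); dim_value a
        # concrete size; neither being set means the dimension is unknown.
        dims.append(d.dim_param or (d.dim_value if d.HasField('dim_value') else '?'))
    print(vi.name, dims)
```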
To Reproduce
Here's the full code:
```python
import numpy as np
import onnx
import torch
from onnx_tf.backend import prepare


def export(model_path):
    # Just check VGG2L
    class TestModel(torch.nn.Module):
        def load_model(self, model):
            self.model = model

        def forward(self, feat, states):
            return self.model.enc.forward(feat, states)

    # Load model
    model = torch.load(model_path, map_location=torch.device('cpu'))
    model.eval()
    test_model = TestModel()
    test_model.load_model(model)

    # Init inputs: 100 frames of 83-dim features plus LSTM states
    length = 100
    input_x = torch.ones([1, length, 83], dtype=torch.float) * 1.3
    h0_in = torch.FloatTensor(np.zeros((6, 1, 512), dtype=float))
    c0_in = torch.FloatTensor(np.zeros((6, 1, 512), dtype=float))
    states = (h0_in, c0_in)
    data_input = (input_x, states)

    # Export to ONNX with a dynamic sequence axis
    torch.onnx.export(test_model, data_input, 'encoder.onnx',
                      opset_version=10,
                      verbose=True,
                      do_constant_folding=True,
                      input_names=['input', 'h0_in', 'c0_in'],
                      output_names=['output', 'h0_out', 'c0_out'],
                      dynamic_axes={'input': {1: 'sequence'},
                                    'output': {1: 'sequence'}})

    # Load the ONNX model and prepare the TensorFlow representation
    onnx_model = onnx.load('encoder.onnx')
    model = prepare(onnx_model)

    # Export the TensorFlow graph
    model.export_graph('encoder.pb')
```
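As a side note (not in the original report), one quick way to confirm the ONNX export itself is functionally sound before attempting the conversion is to run it with onnxruntime; a minimal sketch, assuming onnxruntime is installed and using the input names and shapes from the script above:

```python
# Sanity-check sketch: run the exported ONNX model with onnxruntime
# before converting it to TensorFlow.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession('encoder.onnx')
feeds = {
    'input': np.ones((1, 100, 83), dtype=np.float32) * 1.3,
    'h0_in': np.zeros((6, 1, 512), dtype=np.float32),
    'c0_in': np.zeros((6, 1, 512), dtype=np.float32),
}
outputs = sess.run(None, feeds)
print([o.shape for o in outputs])  # output, h0_out, c0_out
```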
Python, ONNX, ONNX-TF, TensorFlow version
These versions can be obtained by running get_version.py from the util folder.
- Python version: 3.7
- ONNX version: 1.8.0
- ONNX-TF version: 1.7.0 (tf-1.x)
- TensorFlow version: 1.15
Additional context
The model's forward pass looks like this:

```python
def forward(self, xs_pad, prev_state):
    xs_pad = self.vgg(xs_pad)
    xs_pad, states = self.rnn(xs_pad, prev_state)
    return xs_pad, states
```
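Since the report says each part converts fine on its own, here is a rough sketch of what the separate exports might look like; the `vgg` and `rnn` attribute names come from the forward above, `model` is the torch model loaded in the repro script, and the RNN input shape is a placeholder assumption (the real VGG output shape depends on its pooling):

```python
# Hypothetical separate exports; the RNN input shape below is an
# assumption, not taken from the original report.
import torch

vgg_in = torch.ones(1, 100, 83)
torch.onnx.export(model.enc.vgg, vgg_in, 'vgg.onnx', opset_version=10)

rnn_in = torch.randn(1, 25, 512)  # assumed shape of the VGG output
states = (torch.zeros(6, 1, 512), torch.zeros(6, 1, 512))
torch.onnx.export(model.enc.rnn, (rnn_in, states), 'rnn.onnx',
                  opset_version=10)
```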
The shape inference error typically comes from the ONNX checker, which indicates the ONNX model itself has some issues. Please run a simple test with your ONNX file to verify:
```python
import onnx

onnx_path = 'your_model.onnx'
model = onnx.load(onnx_path)
onnx.checker.check_model(model)
```
Hi @chinhuang007
Sorry for the late reply. I ran the verification, and it reports that the ONNX model has no issues.
Here is the onnx model link:
https://www.dropbox.com/s/5hk2wpb56e2kn0m/enc.onnx?dl=0
Could you please have a look?