onnx-tensorflow

Unable to export graph for ONNX LSTM

Open jumelet opened this issue 3 years ago • 6 comments

Describe the bug

I'm trying to convert a PyTorch LSTM model to TensorFlow via ONNX. The conversion from PyTorch to ONNX runs smoothly, and obtaining a tf_rep via onnx_tf.backend.prepare works as well. However, I run into trouble when trying to export the model as a TensorFlow graph via the export_graph method.

Oddly enough, I get two different errors depending on the number of layers in the LSTM.

1 layer:

WARNING:tensorflow:`tf.nn.rnn_cell.MultiRNNCell` is deprecated. This class is equivalent as `tf.keras.layers.StackedRNNCells`, and will be replaced by that in Tensorflow 2.0.

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-2-e714aad9a419> in <module>
     23 
     24 tf_rep = onnx_tf.backend.prepare(onnx_model)
---> 25 tf_rep.export_graph("lstm.pb")

~/miniconda3/lib/python3.8/site-packages/onnx_tf/backend_rep.py in export_graph(self, path)
    127     """
    128     self.tf_module.is_export = True
--> 129     tf.saved_model.save(
    130         self.tf_module,
    131         path,

~/miniconda3/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in save(obj, export_dir, signatures, options)
   1191   # pylint: enable=line-too-long
   1192   metrics.IncrementWriteApi(_SAVE_V2_LABEL)
-> 1193   save_and_return_nodes(obj, export_dir, signatures, options)
   1194   metrics.IncrementWrite()
   1195 

~/miniconda3/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in save_and_return_nodes(obj, export_dir, signatures, options, experimental_skip_checkpoint)
   1226 
   1227   _, exported_graph, object_saver, asset_info, saved_nodes, node_paths = (
-> 1228       _build_meta_graph(obj, signatures, options, meta_graph_def))
   1229   saved_model.saved_model_schema_version = (
   1230       pywrap_libexport.SAVED_MODEL_SCHEMA_VERSION)

~/miniconda3/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in _build_meta_graph(obj, signatures, options, meta_graph_def)
   1397 
   1398   with save_context.save_context(options):
-> 1399     return _build_meta_graph_impl(obj, signatures, options, meta_graph_def)

~/miniconda3/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in _build_meta_graph_impl(obj, signatures, options, meta_graph_def)
   1349                                 wrapped_functions)
   1350   object_saver = util.TrackableSaver(checkpoint_graph_view)
-> 1351   asset_info, exported_graph = _fill_meta_graph_def(
   1352       meta_graph_def, saveable_view, signatures,
   1353       options.namespace_whitelist, options.experimental_custom_gradients)

~/miniconda3/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions, namespace_whitelist, save_custom_gradients)
    855 
    856   with exported_graph.as_default():
--> 857     signatures = _generate_signatures(signature_functions, resource_map)
    858     for concrete_function in saveable_view.concrete_functions:
    859       concrete_function.add_to_graph()

~/miniconda3/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in _generate_signatures(signature_functions, resource_map)
    598         _map_function_arguments_to_created_inputs(argument_inputs,
    599                                                   signature_key, function.name))
--> 600     outputs = _call_function_with_mapped_captures(
    601         function, mapped_inputs, resource_map)
    602     signatures[signature_key] = signature_def_utils.build_signature_def(

~/miniconda3/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in _call_function_with_mapped_captures(function, args, resource_map)
    550 def _call_function_with_mapped_captures(function, args, resource_map):
    551   """Calls `function` in the exported graph, using mapped resource captures."""
--> 552   export_captures = _map_captures_to_created_tensors(function.graph.captures,
    553                                                      resource_map)
    554   # Calls the function quite directly, since we have new captured resource

~/miniconda3/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in _map_captures_to_created_tensors(original_captures, resource_map)
    466           if isinstance(secondary_referrer, base.Trackable):
    467             trackable_referrers.append(secondary_referrer)
--> 468       raise AssertionError(
    469           ("Tried to export a function which references untracked resource {}. "
    470            "TensorFlow objects (e.g. tf.Variable) captured by functions must "

AssertionError: Tried to export a function which references untracked resource Tensor("3139:0", shape=(), dtype=resource). TensorFlow objects (e.g. tf.Variable) captured by functions must be tracked by assigning them to an attribute of a tracked object or assigned to an attribute of the main object directly.

Trackable Python objects referring to this tensor (from gc.get_referrers, limited to two hops):
<tf.Variable 'lstm_kernel_lstm_2:0' shape=(None, None) dtype=float32>
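
For context, this AssertionError is TensorFlow's generic complaint when a function being exported captures a tf.Variable that the saved object does not track; a minimal standalone reproduction (names hypothetical, unrelated to onnx-tf) would look like:

import tensorflow as tf

v = tf.Variable(1.0)  # lives only as a Python global; nothing tracks it

class Untracked(tf.Module):
    @tf.function(input_signature=[])
    def __call__(self):
        return v + 1.0  # captured, but not an attribute of any tracked object

class Tracked(tf.Module):
    def __init__(self):
        super().__init__()
        self.v = v  # assigning it to an attribute makes it tracked

    @tf.function(input_signature=[])
    def __call__(self):
        return self.v + 1.0

tf.saved_model.save(Tracked(), "/tmp/tracked")      # works
tf.saved_model.save(Untracked(), "/tmp/untracked")  # raises the AssertionError above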

Multi-layer:

WARNING:tensorflow:`tf.nn.rnn_cell.MultiRNNCell` is deprecated. This class is equivalent as `tf.keras.layers.StackedRNNCells`, and will be replaced by that in Tensorflow 2.0.

...

~/miniconda3/lib/python3.8/site-packages/tensorflow/python/ops/rnn.py:764 _dynamic_rnn_loop
        raise ValueError(

ValueError: Input size (depth of inputs) must be accessible via shape inference, but saw value None.

I run into the same issues with a uni-directional LSTM as well; in that case a multi-layer LSTM yields the same "untracked resource Tensor" error as the 1-layer Bi-LSTM.
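
Since the multi-layer ValueError complains about an input depth of None, one speculative, untested diagnostic is to re-export the ONNX model without dynamic_axes, so that every dimension stays static (reusing model, input, h0 and c0 from the repro script below):

# Speculative check: export with fully static shapes (no dynamic_axes) to see
# whether the symbolic 'sequence' axis is what breaks onnx-tf's shape inference.
torch.onnx.export(model, (input, (h0, c0)), 'lstm_static.onnx',
                  input_names=['input', 'h0', 'c0'],
                  output_names=['output', 'hn', 'cn'])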

To Reproduce

import torch
import torch.nn as nn
import onnx
import onnx_tf

layer_count = 1

model = nn.LSTM(10, 20, num_layers=layer_count, bidirectional=True)
model.eval()

with torch.no_grad():
    input = torch.randn(5, 3, 10)
    h0 = torch.randn(layer_count * 2, 3, 20)
    c0 = torch.randn(layer_count * 2, 3, 20)
    output, (hn, cn) = model(input, (h0, c0))

torch.onnx.export(model, (input, (h0, c0)), 'lstm.onnx',
                  input_names=['input', 'h0', 'c0'],
                  output_names=['output', 'hn', 'cn'],
                  dynamic_axes={'input': {0: 'sequence'}, 'output': {0: 'sequence'}})
onnx_model = onnx.load('lstm.onnx')

tf_rep = onnx_tf.backend.prepare(onnx_model)
tf_rep.export_graph("lstm.pb")  # error arises here
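
To narrow things down, one can also check that the prepared rep itself executes before exporting. A minimal sketch, assuming onnx-tf's BackendRep.run accepts a dict keyed by the ONNX input names and returns a namedtuple keyed by the output names (untested for this model):

# Sanity check: run the prepared rep directly, to see whether only
# export_graph is broken or execution fails too.
outputs = tf_rep.run({
    'input': input.numpy(),
    'h0': h0.numpy(),
    'c0': c0.numpy(),
})
print(outputs.output.shape)  # expected (5, 3, 40) for the bidirectional case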

ONNX model file

The 1-layer LSTM ONNX file can be found here.

The 4-layer LSTM ONNX file can be found here.

Python, ONNX, ONNX-TF, TensorFlow, PyTorch versions

  • Python version: 3.8.5
  • ONNX version: 1.10.1
  • ONNX-TF version: 1.9.0
  • TensorFlow version: 2.6.0
  • PyTorch version: 1.9.0

I run my scripts in a Jupyter notebook on CPU, on Ubuntu 20.04.

jumelet · Sep 07 '21 11:09

To test this in an additional environment I set up the script in a Google Colab notebook. Weirdly enough, the script runs correctly the first time, with a TensorFlow graph being created and all that.

However, running it again (e.g. with different model parameters) yields the untracked resource Tensor error again; only after restarting the kernel can I run export_graph again.

You can check it for yourself in this Colab notebook: https://colab.research.google.com/drive/1xnDDSmnztO63aNyc4v3NKj5gea7ZnN7O?usp=sharing
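
If the failure on a second run is caused by stale global Keras/TF state, clearing it between runs might avoid the kernel restart. A speculative sketch (tf.keras.backend.clear_session is a standard TF call; whether it actually helps with this bug is an untested assumption):

import onnx
import onnx_tf
import tensorflow as tf

# Speculative workaround: drop leftover global Keras/TF state from the
# previous run before preparing and exporting again.
tf.keras.backend.clear_session()

onnx_model = onnx.load('lstm.onnx')
tf_rep = onnx_tf.backend.prepare(onnx_model)
tf_rep.export_graph("lstm.pb")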

jumelet · Sep 07 '21 12:09

Hi @jumelet, I have run into a similar problem. I can convert the ONNX model to a TensorFlow .pb model 'correctly', but I can't run inference on the exported model. Can you run inference on the model now?
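
For reference, loading and calling the exported model would look roughly like this; a sketch assuming export_graph wrote a SavedModel directory at the given path and that the serving signature keeps the ONNX input/output names:

import numpy as np
import tensorflow as tf

# Sketch: export_graph writes a SavedModel directory (despite the .pb name),
# so it can be loaded with tf.saved_model.load.
loaded = tf.saved_model.load("lstm.pb")
infer = loaded.signatures["serving_default"]

# Assumption: the serving signature expects the ONNX input names from the export.
outputs = infer(input=tf.constant(np.random.randn(5, 3, 10).astype(np.float32)),
                h0=tf.constant(np.random.randn(2, 3, 20).astype(np.float32)),
                c0=tf.constant(np.random.randn(2, 3, 20).astype(np.float32)))
print({name: t.shape for name, t in outputs.items()})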

Best regards

zc1616 · Sep 24 '21 04:09

I ended up dropping this entire approach and implementing my model directly in Keras, haha. It's unfortunate that it's such a hassle to get this working.

jumelet · Sep 24 '21 09:09

Thanks for your reply! Did you define the model in TensorFlow and retrain, or directly define it in Keras and assign the ONNX weights to TensorFlow? Hope you can get it running as soon as possible.

zc1616 · Sep 24 '21 09:09

In my case I had a very simple model of the form Embedding -> LSTM -> Linear decoder. So what I did is initialise a model of that form in Keras, and then cast the torch weights directly to numpy and set the Keras weights. So no ONNX involved at all.

This is far from an ideal solution though, as it requires you to define the architecture manually in Keras. But it did the job for me for now; at least I could proceed with what I wanted to do in Keras/TensorFlow.

Here's the code I used for casting the torch model to a Keras model:

import keras
import keras.layers as layers
import numpy as np
import tensorflow as tf
import torch

# https://github.com/tensorflow/tensorflow/issues/38942
tf.keras.backend.set_image_data_format("channels_last")


class KerasModel(keras.Model):
    def __init__(self, torch_model):
        super().__init__()
        vocab_size = torch_model.encoder.num_embeddings
        emb_dim = torch_model.encoder.embedding_dim
        nhid = torch_model.lstm.hidden_size

        keras_model = keras.Sequential()

        keras_model.add(layers.Embedding(input_dim=vocab_size, output_dim=emb_dim))
        keras_model.add(layers.LSTM(nhid))
        keras_model.add(layers.Dense(torch_model.decoder.out_features))

        keras_model.layers[0].set_weights([torch_model.encoder.weight.detach().numpy()])

        # PyTorch and Keras stack the LSTM gates in the same i, f, g, o order,
        # so the weight matrices only need a transpose. Keras uses a single
        # bias vector where PyTorch has two (ih and hh), hence the sum.
        keras_model.layers[1].set_weights([
            torch_model.lstm.weight_ih_l0.detach().numpy().T,
            torch_model.lstm.weight_hh_l0.detach().numpy().T,
            torch_model.lstm.bias_hh_l0.detach().numpy() + torch_model.lstm.bias_ih_l0.detach().numpy(),
        ])

        keras_model.layers[2].set_weights([
            torch_model.decoder.weight.detach().numpy().T,
            torch_model.decoder.bias.detach().numpy()
        ])

        self.encoder = keras_model.layers[0]
        self.lstm = keras_model.layers[1]
        self.decoder = keras_model.layers[2]

    def call(self, input_ids=None, input_embeds=None):
        assert input_ids is not None or input_embeds is not None

        if input_embeds is None:
            input_embeds = self.encoder(input_ids)
        # Add a batch dimension for unbatched input.
        if input_embeds.ndim == 2:
            input_embeds = input_embeds[np.newaxis, ...]

        hidden = self.lstm(input_embeds)
        logits = self.decoder(hidden)

        return logits
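
Usage would then look something like this (TorchLM here is a hypothetical stand-in; any PyTorch module exposing encoder, lstm and decoder attributes of the matching types should work the same way, assuming the TF/Keras versions listed above):

import numpy as np
import torch.nn as nn

# Hypothetical PyTorch model with the attribute names KerasModel expects.
class TorchLM(nn.Module):
    def __init__(self, vocab_size=100, emb_dim=16, nhid=32):
        super().__init__()
        self.encoder = nn.Embedding(vocab_size, emb_dim)
        self.lstm = nn.LSTM(emb_dim, nhid, batch_first=True)
        self.decoder = nn.Linear(nhid, vocab_size)

keras_model = KerasModel(TorchLM())
logits = keras_model(input_ids=np.array([[1, 2, 3, 4]]))
print(logits.shape)  # (1, vocab_size): logits after the last token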

jumelet · Sep 24 '21 09:09

Good luck!

zc1616 · Sep 24 '21 10:09