Is it possible to save trained model as TF saved_model format? If so how?

Open crobarcro opened this issue 5 years ago • 7 comments

This is a question, but I don't know how to add the question tag.

My question is about exporting my model for use in other systems. The ultimate goal is to get it into ONNX format. I intend to achieve this using tf2onnx. However, the preferred input format for tf2onnx is a tensorflow saved_model format. Therefore I would like to export to this format.

Is this possible, and if so, how?

I tried the following:

import gym
import tensorflow as tf

from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO1


#with tf.Graph().as_default():
#  with tf.Session() as sess:

env = gym.make('CartPole-v1')
# Optional: PPO2 requires a vectorized environment to run
# the env is now wrapped automatically when passing it to the constructor
# env = DummyVecEnv([lambda: env])

model = PPO1(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=1000)

#tf.global_variables_initializer().run(session=model.sess)

init = tf.global_variables_initializer()

model.sess.run(init)

tf.train.Saver ()

saver.save(model.sess, '/my/redacted/save/dir')

But this fails with the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, fetches, contraction_fn)
    302         self._unique_fetches.append(ops.get_default_graph().as_graph_element(
--> 303             fetch, allow_tensor=True, allow_operation=True))
    304       except TypeError as e:

/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in as_graph_element(self, obj, allow_tensor, allow_operation)
   3795     with self._lock:
-> 3796       return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
   3797 

/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in _as_graph_element_locked(self, obj, allow_tensor, allow_operation)
   3879       if obj.graph is not self:
-> 3880         raise ValueError("Operation %s is not an element of this graph." % obj)
   3881       return obj

ValueError: Operation name: "init"
op: "NoOp"
 is not an element of this graph.

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
/home/rcrozier/src/ceorl_core-refactor-hg/openai_gym/common/export_trained_model_to_ONNX_example.py in <module>()
     22 init = tf.global_variables_initializer()
     23 
---> 24 model.sess.run(init)
     25 
     26 tf.train.Saver ()

/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    948     try:
    949       result = self._run(None, fetches, feed_dict, options_ptr,
--> 950                          run_metadata_ptr)
    951       if run_metadata:
    952         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1156     # Create a fetch handler to take care of the structure of fetches.
   1157     fetch_handler = _FetchHandler(
-> 1158         self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
   1159 
   1160     # Run request and get response.

/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, graph, fetches, feeds, feed_handles)
    472     """
    473     with graph.as_default():
--> 474       self._fetch_mapper = _FetchMapper.for_fetch(fetches)
    475     self._fetches = []
    476     self._targets = []

/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py in for_fetch(fetch)
    272         if isinstance(fetch, tensor_type):
    273           fetches, contraction_fn = fetch_fn(fetch)
--> 274           return _ElementFetchMapper(fetches, contraction_fn)
    275     # Did not find anything.
    276     raise TypeError('Fetch argument %r has invalid type %r' % (fetch,

/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, fetches, contraction_fn)
    308       except ValueError as e:
    309         raise ValueError('Fetch argument %r cannot be interpreted as a '
--> 310                          'Tensor. (%s)' % (fetch, str(e)))
    311       except KeyError as e:
    312         raise ValueError('Fetch argument %r cannot be interpreted as a '

ValueError: Fetch argument <tf.Operation 'init' type=NoOp> cannot be interpreted as a Tensor. (Operation name: "init"
op: "NoOp"
 is not an element of this graph.)

I've tried various things, but I'm floundering about a bit. Can anyone confirm it's even possible and if so maybe give some pointers?

I think this would be generally useful for the community.
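A possible reading of this error, offered as a sketch rather than a confirmed diagnosis: the traceback raises because `obj.graph is not self`, i.e. the `init` op was created in TensorFlow's default graph while `model.sess` executes a different one. Assuming the stable-baselines model exposes its graph as `model.graph` (it is referenced that way later in this thread), creating the Saver inside `model.graph.as_default()` should avoid the mismatch. The helper names below are hypothetical.

```python
def op_matches_session(op, sess):
    """Return True if `op` lives in the graph that `sess` executes."""
    # This mirrors the identity check that raises
    # "is not an element of this graph" in the traceback above.
    return op.graph is sess.graph


def save_checkpoint(model, ckpt_prefix):
    """Write the trained variables of `model` to a TF1 checkpoint."""
    import tensorflow as tf  # TF 1.x API assumed; imported lazily

    with model.graph.as_default():
        # Saver() collects variables from the current default graph,
        # so it is created while model.graph is the default.
        saver = tf.train.Saver()
        # No global_variables_initializer here: running a fresh one
        # would overwrite the trained weights with initial values.
        return saver.save(model.sess, ckpt_prefix)
```

Calling `op_matches_session(init, model.sess)` before `model.sess.run(init)` would be a cheap way to confirm whether the two graphs really differ.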

System Info:

Mint Linux 19.3

Everything was installed via pip3

Python version: 3.6.9
Tensorflow version: 1.14
Stable Baselines version: 2.9.0
tf2onnx version: 1.5.4

crobarcro avatar Feb 13 '20 11:02 crobarcro

Hello, did you try taking a look at the doc on exporting models?

Btw, if you succeed, we would appreciate a PR that documents how to do it ;)

araffin avatar Feb 13 '20 11:02 araffin

I did look at this, yes, but while this has some pointers it doesn't quite have a full example anywhere. Actually I haven't yet tried the simple_save method. I will do this and report back.

If I get it to work I'd be happy to update the docs, even just for my own records.

crobarcro avatar Feb 13 '20 13:02 crobarcro

For simple_save I'm supposed to do something like this:

path = '/home/me/savepath'

inputs_dict = {
                "input1": 0,
                "input2": 1
            }

outputs_dict = {
    "output": 0
}

tf.saved_model.simple_save(
    model.sess, path, inputs_dict, outputs_dict
)



However, I don't have a clue what should go in the inputs and outputs dicts here. Any suggestions for what they might be for the CartPole model?

crobarcro avatar Feb 13 '20 14:02 crobarcro

@crobarcro

See the discussions linked behind the docs, e.g. this and this.

If you end up doing a PR for this, it could also include specific examples like these which show which variables are supposed to go as inputs/outputs.

Miffyli avatar Feb 13 '20 14:02 Miffyli

I got the model to export using the following (you can see in comments some of the other variables I tried):

path = '/home/me/savepath'

inputs_dict = {
                #"obs": model.policy_tf.obs_phmodel.policy.obs_ph
                #"obs": model.policy.obs_ph
                "obs": model.act_model.obs_ph
              }

outputs_dict = {
    #"action": model.policy.action_ph
    "action": model.action_ph
}

tf.saved_model.simple_save(
    model.sess, path, inputs_dict, outputs_dict
)

However, I've no idea yet whether this is actually correct. If anyone can see that this is obviously wrong, it would be helpful to know. Note that I switched to PPO2.

I tried using the tensorflow summarize_graph tool (also described on the tf2onnx page). This is supposed to display the input and output names in the saved graph. However, I get the following when run on the .pb file I created:

$ /home/rcrozier/src/tensorflow-git/bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph="saved_model.pb"
[libprotobuf ERROR external/protobuf_archive/src/google/protobuf/text_format.cc:312] Error parsing text-format tensorflow.GraphDef: 1:1: Invalid control characters encountered in text.
[libprotobuf ERROR external/protobuf_archive/src/google/protobuf/text_format.cc:312] Error parsing text-format tensorflow.GraphDef: 1:4: Interpreting non ascii codepoint 194.
[libprotobuf ERROR external/protobuf_archive/src/google/protobuf/text_format.cc:312] Error parsing text-format tensorflow.GraphDef: 1:4: Expected identifier, got: �
2020-02-14 10:34:19.643890: E tensorflow/tools/graph_transforms/summarize_graph_main.cc:320] Loading graph 'saved_model.pb' failed with Can't parse saved_model.pb as binary proto
	 (both text and binary parsing failed for file saved_model.pb)
2020-02-14 10:34:19.643985: E tensorflow/tools/graph_transforms/summarize_graph_main.cc:322] usage: /home/rcrozier/src/tensorflow-git/bazel-bin/tensorflow/tools/graph_transforms/summarize_graph
Flags:
	--in_graph=""                    	string	input graph file name
	--print_structure=false          	bool	whether to print the network connections of the graph

Incidentally, this required building TensorFlow from the git sources (using the stupid Bazel build system). I tried both TensorFlow 2.1 and TensorFlow 1.14 in case there was some incompatibility, but I got the same result.
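One likely cause, stated as an assumption: `summarize_graph` expects a GraphDef, whereas the `saved_model.pb` written by `simple_save` holds a SavedModel message, so the parse failure may not mean the export itself is bad. TensorFlow's pip package ships a `saved_model_cli` tool that reads the export directory directly, and tf2onnx accepts the same directory through `--saved-model`. The sketch below only assembles the two command lines; the function names and the output filename are hypothetical.

```python
import subprocess
import sys


def show_command(export_dir):
    """Command that prints the signatures stored in a SavedModel directory."""
    return ["saved_model_cli", "show", "--dir", export_dir, "--all"]


def convert_command(export_dir, onnx_path):
    """Command that asks tf2onnx to convert a SavedModel directory to ONNX."""
    return [sys.executable, "-m", "tf2onnx.convert",
            "--saved-model", export_dir, "--output", onnx_path]


def run(cmd):
    """Run a command list and raise if the tool exits with an error."""
    return subprocess.run(cmd, check=True)
```

For example, `run(show_command("export_model_test"))` should list the `obs` input and `action` output if the export contains them. Note that both tools take the directory, not the path to `saved_model.pb`.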

I'm going to keep trying and will report back progress on this issue.

crobarcro avatar Feb 14 '20 10:02 crobarcro

Looking at the policy (in common/ folder), this should be: 'action': model.act_model._policy_proba (cf https://github.com/hill-a/stable-baselines/issues/474) which corresponds to the output of the policy.

action_ph is used for training (it is a batch of actions).
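To illustrate the distinction: assuming a discrete action space such as CartPole, `_policy_proba` should yield one row of action probabilities per observation, so whatever consumes the exported model still has to pick an action from that row. A minimal sketch with a hypothetical helper, using plain lists in place of the array a session would return:

```python
def greedy_actions(action_probs):
    """Pick the index of the most probable action for each row of a batch."""
    return [max(range(len(row)), key=row.__getitem__) for row in action_probs]


# Two observations, two possible actions each (made-up numbers).
batch = [[0.3, 0.7], [0.9, 0.1]]
print(greedy_actions(batch))
```

This corresponds to deterministic prediction; sampling from each row instead would give stochastic behaviour.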

araffin avatar Feb 14 '20 12:02 araffin

Thanks. I had actually seen that issue and made the change after @Miffyli pointed it out (I had skimmed over both issues previously, but missed the crucial details).

Anyway, I tried the suggested names, see the script below, which just attempts to save the model and then load it again with tensorflow:

import shutil, os
import gym
import tensorflow as tf

from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO2


#with tf.Graph().as_default():
#  with tf.Session() as sess:

env = gym.make('CartPole-v1')
# Optional: PPO2 requires a vectorized environment to run
# the env is now wrapped automatically when passing it to the constructor
# env = DummyVecEnv([lambda: env])

model = PPO2(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=1000)


containing_dir = os.path.dirname(os.path.realpath(__file__))

path = os.path.join(containing_dir, 'export_model_test')

shutil.rmtree (path, ignore_errors=True)

os.mkdir (path)

##########          using simple_save          #############
inputs_dict = {
               #"obs": model.policy_tf.obs_phmodel.policy.obs_ph
               #"obs": model.policy.obs_ph
               "obs": model.act_model.obs_ph
             }

outputs_dict = {
   #"action": model.policy.action_ph
   #"action": model.action_ph
   "action": model.act_model._policy_proba
}

tf.saved_model.simple_save(
   model.sess, path, inputs_dict, outputs_dict
)
#############################################################

##########          using tf.saved_model.loader         #############
# init = tf.global_variables_initializer()
# model.sess.run(init)
# saver = tf.train.Saver()
# saver.save(model.sess, os.path.join(path, 'tensorflowModel.ckpt'))
# tf.train.write_graph(model.sess.graph.as_graph_def(), path, 'tensorflowModel.pbtxt', as_text=True)
#####################################################################

# ##########          using import/export meta_graph          #############
# meta_file = os.path.join(path, 'saved_model.meta')
# meta_graph_def = tf.train.export_meta_graph( filename = meta_file,
#                                              graph=model.graph,
#                                              graph_def=model.graph.as_graph_def() )
############################################################################

# I think I need to close the session to free any resources
model.sess.close ()

##########          using tf.saved_model.loader         #############
with tf.Session() as sess:
   tf.saved_model.loader.load(sess, [tf.saved_model.SERVING], path)

   graph = tf.get_default_graph()

   print(graph.get_operations())
#######################################################################

# ##########          using tf.saved_model.loader             #############
# restored_graph = tf.Graph()
# with restored_graph.as_default():
#     with tf.Session() as sess:
#         tf.saved_model.loader.load(
#             sess,
#             [tf.saved_model.SERVING],
#             path,
#         )
#         obs_placeholder = restored_graph.get_tensor_by_name('obs:0')
#
#         sess.run(prediction, feed_dict={
#             obs_placeholder: some_value,
#         })
############################################################################


# ##########          using import_meta_graph          #############
# with tf.Session() as sess:
#     new_saver = tf.train.import_meta_graph(meta_file)
#     new_saver.restore(sess, meta_file)
#
#     # sess.run(prediction, feed_dict={
#     #          obs_placeholder: some_value,
#     #      })
######################################################################

With this, however, I get the following output.

INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: /home/rcrozier/src/ceorl_core-refactor-hg/openai_gym/common/export_model_test/saved_model.pb
WARNING:tensorflow:From /home/rcrozier/src/ceorl_core-refactor-hg/openai_gym/common/export_trained_model_to_ONNX_example.py:67: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored.
[]

There are two commented-out sections which represent alternative methods of saving/loading that I have found. If I try to save using tf.train.Saver like this:

init = tf.global_variables_initializer()
model.sess.run(init)
saver = tf.train.Saver()
saver.save(model.sess, os.path.join(path, 'tensorflowModel.ckpt'))
tf.train.write_graph(model.sess.graph.as_graph_def(), path, 'tensorflowModel.pbtxt', as_text=True)

I get

ValueError: Fetch argument <tf.Operation 'init' type=NoOp> cannot be interpreted as a Tensor. (Operation name: "init"
op: "NoOp"
 is not an element of this graph.)

If I do this:

#init = tf.global_variables_initializer()
# model.sess.run(init)
saver = tf.train.Saver()
saver.save(model.sess, os.path.join(path, 'tensorflowModel.ckpt'))
tf.train.write_graph(model.sess.graph.as_graph_def(), path, 'tensorflowModel.pbtxt', as_text=True)

I get this:

ValueError: No variables to save

Exporting the meta graph seems to work, but when I load it there doesn't seem to be anything in the graph.
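A possible explanation for the empty graph and the "No variables to save" error, assuming the model keeps its network in `model.graph` rather than in TensorFlow's default graph: `tf.train.Saver()` and `simple_save` both look at whatever graph is the default at the time they are called. Below is a sketch of the export and reload with the graph made explicit. The function names are hypothetical, and the tensor names are the ones suggested earlier in this thread.

```python
def export_saved_model(model, export_dir):
    """Export the policy of a stable-baselines (TF1) model as a SavedModel."""
    import tensorflow as tf  # TF 1.x API assumed; imported lazily

    inputs = {"obs": model.act_model.obs_ph}
    outputs = {"action": model.act_model._policy_proba}
    # Make the model's own graph the default so that simple_save
    # sees its variables and ops instead of an empty default graph.
    with model.graph.as_default():
        tf.saved_model.simple_save(model.sess, export_dir, inputs, outputs)


def load_saved_model(export_dir):
    """Load a SavedModel into a fresh graph and return (session, graph)."""
    import tensorflow as tf  # TF 1.x API assumed; imported lazily

    graph = tf.Graph()
    with graph.as_default():
        sess = tf.Session(graph=graph)
        tf.saved_model.loader.load(sess, [tf.saved_model.SERVING], export_dir)
    return sess, graph
```

After loading, `graph.get_operations()` should be non-empty if the export picked up the model's graph.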

crobarcro avatar Feb 14 '20 13:02 crobarcro