Is it possible to save a trained model in TF saved_model format? If so, how?
This is a question, but I don't know how to add the question tag.
My question is about exporting my model for use in other systems. The ultimate goal is to get it into ONNX format. I intend to achieve this using tf2onnx. However, the preferred input format for tf2onnx is a tensorflow saved_model format. Therefore I would like to export to this format.
Is this possible, and if so, how?
I tried the following:
import gym
import tensorflow as tf
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO1
#with tf.Graph().as_default():
# with tf.Session() as sess:
env = gym.make('CartPole-v1')
# Optional: PPO2 requires a vectorized environment to run
# the env is now wrapped automatically when passing it to the constructor
# env = DummyVecEnv([lambda: env])
model = PPO1(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=1000)
#tf.global_variables_initializer().run(session=model.sess)
init = tf.global_variables_initializer()
model.sess.run(init)
tf.train.Saver ()
saver.save(model.sess, '/my/redacted/save/dir')
But this fails with the following error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, fetches, contraction_fn)
302 self._unique_fetches.append(ops.get_default_graph().as_graph_element(
--> 303 fetch, allow_tensor=True, allow_operation=True))
304 except TypeError as e:
/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in as_graph_element(self, obj, allow_tensor, allow_operation)
3795 with self._lock:
-> 3796 return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
3797
/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in _as_graph_element_locked(self, obj, allow_tensor, allow_operation)
3879 if obj.graph is not self:
-> 3880 raise ValueError("Operation %s is not an element of this graph." % obj)
3881 return obj
ValueError: Operation name: "init"
op: "NoOp"
is not an element of this graph.
During handling of the above exception, another exception occurred:
ValueError Traceback (most recent call last)
/home/rcrozier/src/ceorl_core-refactor-hg/openai_gym/common/export_trained_model_to_ONNX_example.py in <module>()
22 init = tf.global_variables_initializer()
23
---> 24 model.sess.run(init)
25
26 tf.train.Saver ()
/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
948 try:
949 result = self._run(None, fetches, feed_dict, options_ptr,
--> 950 run_metadata_ptr)
951 if run_metadata:
952 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
1156 # Create a fetch handler to take care of the structure of fetches.
1157 fetch_handler = _FetchHandler(
-> 1158 self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
1159
1160 # Run request and get response.
/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, graph, fetches, feeds, feed_handles)
472 """
473 with graph.as_default():
--> 474 self._fetch_mapper = _FetchMapper.for_fetch(fetches)
475 self._fetches = []
476 self._targets = []
/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py in for_fetch(fetch)
272 if isinstance(fetch, tensor_type):
273 fetches, contraction_fn = fetch_fn(fetch)
--> 274 return _ElementFetchMapper(fetches, contraction_fn)
275 # Did not find anything.
276 raise TypeError('Fetch argument %r has invalid type %r' % (fetch,
/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, fetches, contraction_fn)
308 except ValueError as e:
309 raise ValueError('Fetch argument %r cannot be interpreted as a '
--> 310 'Tensor. (%s)' % (fetch, str(e)))
311 except KeyError as e:
312 raise ValueError('Fetch argument %r cannot be interpreted as a '
ValueError: Fetch argument <tf.Operation 'init' type=NoOp> cannot be interpreted as a Tensor. (Operation name: "init"
op: "NoOp"
is not an element of this graph.)
I've tried various things, but I'm floundering about a bit. Can anyone confirm it's even possible and if so maybe give some pointers?
I think this would be generally useful for the community.
System Info
Mint Linux 19.3
Everything was installed via pip3
Python version: 3.6.9
Tensorflow version: 1.14
Stable Baselines version: 2.9.0
tf2onnx version: 1.5.4
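A possible reading of the traceback above, offered as an assumption rather than a confirmed diagnosis: Stable Baselines seems to build each model inside its own tf.Graph (exposed as model.graph), while tf.global_variables_initializer() called at module level adds its "init" op to TensorFlow's default graph. A session can only run ops that belong to the graph it was created with. The sketch below checks for that mismatch. The helper names are hypothetical and exist only for illustration; the only thing relied on is the graph attribute that TensorFlow 1.x sessions and operations both expose.

```python
def op_belongs_to_session(op, sess):
    """Return True if `op` lives in the graph that `sess` was created with."""
    return op.graph is sess.graph


def describe_graph_mismatch(op, sess):
    """Return a short, human-readable diagnosis for a fetch that fails to run."""
    if op_belongs_to_session(op, sess):
        return "op and session share a graph"
    return "op was created in a different graph than the session uses"
```

With the script above this would be called as describe_graph_mismatch(init, model.sess). Note also that running a variable initializer on an already trained model would reset its weights, so the initializer is probably not wanted here at all.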
Hello, did you try taking a look at the doc on exporting models?
Btw, if you succeed, we would appreciate a PR that documents how to do it ;)
I did look at this, yes, but while it has some pointers, it doesn't have a full example anywhere. I haven't yet tried the simple_save method; I will do this and report back.
If I get it to work I'd be happy to update the docs, even just for my own records.
For simple_save I'm supposed to do something like this:
path = '/home/me/savepath'
inputs_dict = {
    "input1": 0,
    "input2": 1
}
outputs_dict = {
    "output": 0
}
tf.saved_model.simple_save(
    model.sess, path, inputs_dict, outputs_dict
)
However, I don't have a clue what should go in the inputs and outputs dicts here. Any suggestions for what they might be for the cartpole model?
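One way to look for candidate inputs, sketched under the assumption that the model exposes its TensorFlow 1.x graph as model.graph: list every op of type "Placeholder" in that graph. The function name list_placeholders is hypothetical; get_operations(), op.type and op.name are standard tf.Graph / tf.Operation members.

```python
def list_placeholders(graph):
    """Return the names of all Placeholder ops in a TensorFlow 1.x graph."""
    return [op.name for op in graph.get_operations() if op.type == "Placeholder"]
```

Usage would be along the lines of print(list_placeholders(model.graph)). This only narrows down the inputs; which tensor is the right output still has to come from reading the policy code.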
@crobarcro
See the discussions linked behind the docs, e.g. this and this.
If you end up doing a PR for this, it could also include specific examples like these which show which variables are supposed to go as inputs/outputs.
I got the model to export using the following (you can see in comments some of the other variables I tried):
path = '/home/me/savepath'
inputs_dict = {
    #"obs": model.policy_tf.obs_phmodel.policy.obs_ph
    #"obs": model.policy.obs_ph
    "obs": model.act_model.obs_ph
}
outputs_dict = {
    #"action": model.policy.action_ph
    "action": model.action_ph
}
tf.saved_model.simple_save(
    model.sess, path, inputs_dict, outputs_dict
)
However, I've no idea yet whether this is actually correct. If anyone can see that this is obviously wrong, it would be helpful to know. Note I switched to PPO2.
I tried using the tensorflow summarize_graph tool (also described on the tf2onnx page). This is supposed to display the input and output names in the saved graph. However, I get the following when run on the .pb file I created:
$ /home/rcrozier/src/tensorflow-git/bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph="saved_model.pb"
[libprotobuf ERROR external/protobuf_archive/src/google/protobuf/text_format.cc:312] Error parsing text-format tensorflow.GraphDef: 1:1: Invalid control characters encountered in text.
[libprotobuf ERROR external/protobuf_archive/src/google/protobuf/text_format.cc:312] Error parsing text-format tensorflow.GraphDef: 1:4: Interpreting non ascii codepoint 194.
[libprotobuf ERROR external/protobuf_archive/src/google/protobuf/text_format.cc:312] Error parsing text-format tensorflow.GraphDef: 1:4: Expected identifier, got: �
2020-02-14 10:34:19.643890: E tensorflow/tools/graph_transforms/summarize_graph_main.cc:320] Loading graph 'saved_model.pb' failed with Can't parse saved_model.pb as binary proto
(both text and binary parsing failed for file saved_model.pb)
2020-02-14 10:34:19.643985: E tensorflow/tools/graph_transforms/summarize_graph_main.cc:322] usage: /home/rcrozier/src/tensorflow-git/bazel-bin/tensorflow/tools/graph_transforms/summarize_graph
Flags:
--in_graph="" string input graph file name
--print_structure=false bool whether to print the network connections of the graph
Incidentally, this required building tensorflow from the git sources (using the stupid Bazel build system). I tried both tensorflow 2.1 and tensorflow 1.14 in case there was some incompatibility, but I got the same result.
I'm going to keep trying and will report back progress on this issue.
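A likely reason for the parse failure above, stated as an assumption: summarize_graph expects a GraphDef, whereas the saved_model.pb written by simple_save holds a SavedModel protocol buffer that wraps the graph. TensorFlow installs a saved_model_cli tool that reads the SavedModel directory directly and prints the signature inputs and outputs. The sketch below wraps that command from Python; the function names are hypothetical, and the directory argument is whatever was passed to simple_save.

```python
import subprocess


def saved_model_show_command(export_dir):
    """Build the saved_model_cli command that prints every signature in export_dir."""
    return ["saved_model_cli", "show", "--dir", export_dir, "--all"]


def show_saved_model(export_dir):
    """Run saved_model_cli and return its text output (needs TensorFlow installed)."""
    result = subprocess.run(
        saved_model_show_command(export_dir),
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,  # text mode; spelled this way for Python 3.6
        check=True,
    )
    return result.stdout
```

Note that the tool takes the directory containing saved_model.pb, not the .pb file itself.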
Looking at the policy (in the common/ folder), this should be:
'action': model.act_model._policy_proba
which corresponds to the output of the policy (cf. https://github.com/hill-a/stable-baselines/issues/474).
action_ph is used for training (it is a batch of actions).
Thanks, actually I had seen that issue and made the change after @Miffyli had pointed it out (actually I had skimmed over both issues previously, but missed the crucial details).
Anyway, I tried the suggested names, see the script below, which just attempts to save the model and then load it again with tensorflow:
import shutil, os
import gym
import tensorflow as tf
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO2
#with tf.Graph().as_default():
# with tf.Session() as sess:
env = gym.make('CartPole-v1')
# Optional: PPO2 requires a vectorized environment to run
# the env is now wrapped automatically when passing it to the constructor
# env = DummyVecEnv([lambda: env])
model = PPO2(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=1000)
containing_dir = os.path.dirname(os.path.realpath(__file__))
path = os.path.join(containing_dir, 'export_model_test')
shutil.rmtree (path, ignore_errors=True)
os.mkdir (path)
########## using simple_save #############
inputs_dict = {
    #"obs": model.policy_tf.obs_phmodel.policy.obs_ph
    #"obs": model.policy.obs_ph
    "obs": model.act_model.obs_ph
}
outputs_dict = {
    #"action": model.policy.action_ph
    #"action": model.action_ph
    "action": model.act_model._policy_proba
}
tf.saved_model.simple_save(
    model.sess, path, inputs_dict, outputs_dict
)
#############################################################
########## using tf.train.Saver #############
# init = tf.global_variables_initializer()
# model.sess.run(init)
# saver = tf.train.Saver()
# saver.save(model.sess, os.path.join(path, 'tensorflowModel.ckpt'))
# tf.train.write_graph(model.sess.graph.as_graph_def(), path, 'tensorflowModel.pbtxt', as_text=True)
#####################################################################
# ########## using import/export meta_graph #############
# meta_file = os.path.join(path, 'saved_model.meta')
# meta_graph_def = tf.train.export_meta_graph( filename = meta_file,
# graph=model.graph,
# graph_def=model.graph.as_graph_def() )
############################################################################
# I think I need to close the session to free any resources
model.sess.close ()
########## using tf.saved_model.loader #############
with tf.Session() as sess:
    tf.saved_model.loader.load(sess, [tf.saved_model.SERVING], path)
    graph = tf.get_default_graph()
    print(graph.get_operations())
#######################################################################
# ########## using tf.saved_model.loader #############
# restored_graph = tf.Graph()
# with restored_graph.as_default():
#     with tf.Session() as sess:
#         tf.saved_model.loader.load(
#             sess,
#             [tf.saved_model.SERVING],
#             path,
#         )
#         obs_placeholder = restored_graph.get_tensor_by_name('obs:0')
#
#         sess.run(prediction, feed_dict={
#             obs_placeholder: some_value,
#         })
############################################################################
# ########## using import_meta_graph #############
# with tf.Session() as sess:
#     new_saver = tf.train.import_meta_graph(meta_file)
#     new_saver.restore(sess, meta_file)
#
#     # sess.run(prediction, feed_dict={
#     #     obs_placeholder: some_value,
#     # })
######################################################################
With this, however, I get the following output:
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: /home/rcrozier/src/ceorl_core-refactor-hg/openai_gym/common/export_model_test/saved_model.pb
WARNING:tensorflow:From /home/rcrozier/src/ceorl_core-refactor-hg/openai_gym/common/export_trained_model_to_ONNX_example.py:67: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored.
[]
There are two commented-out sections which represent alternative methods of saving/loading that I have found. If I try to save using tf.train.Saver like this:
init = tf.global_variables_initializer()
model.sess.run(init)
saver = tf.train.Saver()
saver.save(model.sess, os.path.join(path, 'tensorflowModel.ckpt'))
tf.train.write_graph(model.sess.graph.as_graph_def(), path, 'tensorflowModel.pbtxt', as_text=True)
I get
ValueError: Fetch argument <tf.Operation 'init' type=NoOp> cannot be interpreted as a Tensor. (Operation name: "init"
op: "NoOp"
is not an element of this graph.)
If I do this:
#init = tf.global_variables_initializer()
# model.sess.run(init)
saver = tf.train.Saver()
saver.save(model.sess, os.path.join(path, 'tensorflowModel.ckpt'))
tf.train.write_graph(model.sess.graph.as_graph_def(), path, 'tensorflowModel.pbtxt', as_text=True)
I get this:
ValueError: No variables to save
Exporting the meta graph seems to work, but when I load it there doesn't seem to be anything in the graph.
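One possible explanation for both the empty loaded graph and the "No variables to save" error, stated as an assumption rather than a confirmed fix: the model's ops and variables live in model.graph, while tf.train.Saver() and tf.saved_model.simple_save work on whatever graph is currently the default. A minimal sketch of exporting with the model's graph made the default follows. The helper names build_signature and export_saved_model are hypothetical, the tensor choices are the ones suggested earlier in this thread, and the TensorFlow 1.x API is assumed.

```python
def build_signature(model):
    """Map signature names to the tensors suggested earlier in this thread."""
    inputs = {"obs": model.act_model.obs_ph}
    outputs = {"action": model.act_model._policy_proba}
    return inputs, outputs


def export_saved_model(model, export_dir):
    """Write a SavedModel while the model's own graph is the default graph."""
    import tensorflow as tf  # imported lazily; TensorFlow 1.x API assumed

    inputs, outputs = build_signature(model)
    # Assumption: model.graph holds the trained variables, so it must be the
    # default graph while simple_save collects what to write.
    with model.graph.as_default():
        tf.saved_model.simple_save(model.sess, export_dir, inputs, outputs)
```

If this works, the resulting directory could then be handed to tf2onnx, e.g. python -m tf2onnx.convert --saved-model <export_dir> --output model.onnx, which is the original goal. No variable initializer is run here, since that would overwrite the trained weights.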