Saving a SavedModel after loading a graph
Hi, is there a way to save a SavedModel after training, having loaded the graph from a SavedModel (as opposed to initialising the layers & vars as in `examples/xor.rs`)?
It seems like `SavedModelBuilder` requires a collection of variables and a scope, which don't seem straightforward to get just from an already-existing graph.
I'm also interested in this, but have yet to find any workaround. It seems that while the signatures can be obtained from `MetaGraphDef::signatures`, there's no way to get the collection of variables and scopes when a model is loaded via `SavedModelBundle::load` or `Session::from_saved_model`.
EDIT 2020/12/20: What I want to do is build the model in Python and then train it from Rust, so I'm looking for a way to keep all the trained parameters, either as a SavedModel or as checkpoints. Things I've tried so far:
- Make the saving operations part of the saved model itself:
  - Include the saving operations (`tf.saved_model.save()` or `model.save()`) in a `@tf.function` on the Python side. This gives me error messages like `save is not supported inside a traced @tf.function`, so it doesn't work (see the sketch after this list).
  - Include the Keras `model.save_weights()` in a `@tf.function` on the Python side. This gives me error messages like `RuntimeError: Cannot get session inside Tensorflow graph function`.
- Try to find a way to get the parameters needed for `tensorflow::SavedModelBuilder` from the Rust side, but this is beyond my knowledge of TensorFlow.
- Try to use the `tf.train.Saver` operation, as suggested in #30, but this seems to be no longer available in TensorFlow 2.0 and onwards.
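For reference, a minimal sketch of the first failed attempt; the toy model and path are made up for illustration, and the quoted error is the one I hit:

```python
import tensorflow as tf

# Hypothetical toy model, just to reproduce the failure.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(2,))])

@tf.function
def export_model():
    # TF 2.x rejects this while tracing, with an error like:
    #   "save is not supported inside a traced @tf.function"
    tf.saved_model.save(model, '/tmp/exported')

export_model()
```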
Another interesting finding: in the experimental C API (`c_api_experimental.h`), while there's a `TF_CheckpointReader` type for reading checkpoints and `TF_LoadSessionFromSavedModel` for reading saved models, there seem to be no functions for saving state.
I found a (quite hacky and ugly) workaround to this.
- From the Python side, include a `tf.train.Checkpoint.write` call inside a `@tf.function`. Unlike `tf.train.Checkpoint.save`, the `write` function is written purely using TF operations, so it runs fine even in graph mode (a fuller Python sketch of this setup follows the list).
- Save a concrete function of that together with the model. For example:

  ```python
  tf.saved_model.save(my_model, modelpath, signatures={
      'ckpt_write': my_model.ckpt_write,
      # Other functions to be saved ...
  })
  ```

- Run that `ckpt_write` function from the Rust side. For example:

  ```rust
  let mut args = tf::SessionRunArgs::new();
  // The name of the output can be obtained by something like this:
  // bundle.meta_graph_def().signatures()["ckpt_write"].outputs()["output_0"].name()
  args.add_target(
      &graph
          .operation_by_name(&name_of_output_of_ckpt_write)?
          .unwrap(),
  );
  session.run(&mut args)?;
  ```
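To make the first two steps concrete, here's a minimal Python sketch of the kind of module this assumes; `MyModel`, its variables, and the paths are illustrative, not from my actual code:

```python
import tensorflow as tf

class MyModel(tf.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical trainable parameters.
        self.w = tf.Variable(tf.random.normal([2, 1]))
        self.b = tf.Variable(tf.zeros([1]))
        # Checkpoint object tracking the variables to persist.
        self.ckpt = tf.train.Checkpoint(w=self.w, b=self.b)

    @tf.function(input_signature=[])
    def ckpt_write(self):
        # Checkpoint.write (unlike Checkpoint.save) is built from TF ops,
        # so it can be traced into the graph and triggered from Rust.
        return self.ckpt.write('/tmp/ckpt/trained')

my_model = MyModel()
tf.saved_model.save(my_model, '/tmp/saved_model', signatures={
    'ckpt_write': my_model.ckpt_write,
})
```

Exporting `ckpt_write` as a signature is what makes it show up in `bundle.meta_graph_def().signatures()` on the Rust side.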
This makes it possible to save the state of a network trained from the Rust side, but I still haven't found any way to restore from the checkpoint using Rust. Restoring can, however, be done by:
- Read the checkpoint using the Python `tf.train.Checkpoint.read` function (sketched below).
- Save the whole model again (for Rust to read the whole model again using `SavedModelBundle::load`).
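As a sketch, using the hypothetical names and paths from the module above:

```python
# Restore the checkpoint that Rust wrote via the ckpt_write signature,
# then re-export the whole model so Rust can reload the updated weights.
my_model = MyModel()
my_model.ckpt.read('/tmp/ckpt/trained')
tf.saved_model.save(my_model, '/tmp/saved_model', signatures={
    'ckpt_write': my_model.ckpt_write,
})
```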
It has been 2 years; is there a nice way of doing this by now?
https://github.com/tensorflow/rust/blob/ae92c63703e9c197f0f1e21e9ce5f49a3f7d6bd0/src/saved_model.rs#L258
Well, as of 2023, there's no support for the other fields either.
What a pity that this essential functionality is not yet covered properly.