
Saving a SavedModel after loading a graph

lucamc9 opened this issue 3 years ago • 5 comments

Hi, is there a way to save a SavedModel after training, having loaded the graph from an existing SavedModel (as opposed to initialising the layers & variables as in examples/xor.rs)?

It seems like the SavedModelBuilder requires a collection of variables and a scope, which don't seem straightforward to obtain from an already-existing graph.

lucamc9 avatar Oct 15 '20 14:10 lucamc9

I'm also interested in this, but have yet to find a workaround. While the signatures can be obtained from MetaGraphDef::signatures, there seems to be no way to get the collection of variables and the scope when a model is loaded via SavedModelBundle::load or Session::from_saved_model.

EDIT 2020/12/20: What I wanted to do is to build the model in Python, and then train the model from Rust, so I'm looking for a way to keep all the trained parameters, either using SavedModel or checkpoints. Things I've tried so far:

  • Make the saving operations part of the saved model itself
    • Include the saving operations (tf.saved_model.save() or model.save()) in a @tf.function from the Python side. This gives me error messages like save is not supported inside a traced @tf.function so it doesn't work.
    • Include the Keras model.save_weights() in a @tf.function from the Python side. This gives me error messages like RuntimeError: Cannot get session inside Tensorflow graph function.
  • Try to find a way to get the parameters needed for tensorflow::SavedModelBuilder from the Rust side, but this is beyond my knowledge of TensorFlow.
  • Try to use the tf.train.Saver operation, as suggested in #30, but this seems to be no longer available in TensorFlow 2.0 and onwards.
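For illustration, the build-in-Python / train-from-Rust setup described above might look like the following minimal sketch. Everything here is hypothetical: the LinearModel class, the train_step signature name, and the learning rate are made up for this example; the point is only that a @tf.function with an input_signature can be exported as a SavedModel signature and then driven step-by-step from Rust.

```python
import tempfile
import tensorflow as tf

class LinearModel(tf.Module):
    """Toy model built in Python; the training step is exported as a
    SavedModel signature so it can be driven from the Rust side."""
    def __init__(self):
        self.w = tf.Variable(0.0)

    @tf.function(input_signature=[
        tf.TensorSpec([None], tf.float32),
        tf.TensorSpec([None], tf.float32),
    ])
    def train_step(self, x, y):
        with tf.GradientTape() as tape:
            loss = tf.reduce_mean(tf.square(self.w * x - y))
        grad = tape.gradient(loss, self.w)
        self.w.assign_sub(0.1 * grad)  # plain SGD step, no optimizer object
        return {"loss": loss}

model = LinearModel()
export_dir = tempfile.mkdtemp()
tf.saved_model.save(model, export_dir,
                    signatures={"train_step": model.train_step})
```

Because train_step mutates the captured variable, repeatedly running the `train_step` signature from Rust (via SessionRunArgs on the corresponding operations) advances training in the loaded session; the open question in this thread is how to persist those updated variables afterwards.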

Another interesting finding: in the experimental C API (c_api_experimental.h), while there's a TF_CheckpointReader type for reading checkpoints and TF_LoadSessionFromSavedModel for reading saved models, there seem to be no functions available for saving state.

kotatsuyaki avatar Dec 19 '20 16:12 kotatsuyaki

I found a (quite hacky and ugly) workaround to this.

  • From the Python side, include a tf.train.Checkpoint.write call inside a @tf.function. Unlike tf.train.Checkpoint.save, the write function is written purely using tf operations, so it runs fine even in graph mode.
  • Save a concrete function of that together with the model. For example,
    tf.saved_model.save(my_model, modelpath, signatures={
      'ckpt_write': my_model.ckpt_write,
      # Other functions to be saved ...
    })
    
  • Run that ckpt_write function from the Rust side. For example,
    let mut args = tf::SessionRunArgs::new();
    // The name of the output can be obtained with something like:
    // bundle.meta_graph_def().signatures()["ckpt_write"].outputs()["output_0"].name()
    // (operation_by_name expects the operation name, so strip any trailing
    // ":0" output index from that tensor name first.)
    args.add_target(
        &graph
            .operation_by_name(&name_of_output_of_ckpt_write)?
            .unwrap(),
    );
    session.run(&mut args)?;
    
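Putting the Python side of this workaround together might look like the sketch below. The model, the fixed checkpoint prefix, and the ckpt_write signature name are all illustrative assumptions; the key idea, per the comment above, is that tf.train.Checkpoint.write (unlike Checkpoint.save) is graph-compatible and can therefore be traced inside a @tf.function and exported as a signature.

```python
import os
import tempfile
import tensorflow as tf

# Hypothetical fixed checkpoint prefix, baked into the graph at trace time.
CKPT_DIR = tempfile.mkdtemp()
CKPT_PREFIX = os.path.join(CKPT_DIR, "ckpt")

class MyModel(tf.Module):
    """Stand-in for the real model."""
    def __init__(self):
        self.w = tf.Variable(3.0)
        self.ckpt = tf.train.Checkpoint(w=self.w)

    @tf.function(input_signature=[])
    def ckpt_write(self):
        # Checkpoint.write is built from plain graph ops, so it can be
        # traced here; Checkpoint.save would fail inside a tf.function.
        return self.ckpt.write(CKPT_PREFIX)

model = MyModel()
export_dir = tempfile.mkdtemp()
tf.saved_model.save(model, export_dir, signatures={
    "ckpt_write": model.ckpt_write,
})
```

Running the `ckpt_write` signature from Rust (as in the snippet above) should then write the current variable values to CKPT_PREFIX.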

This makes it possible to save the state of a network trained from the Rust side, but I still haven't found any way to restore from the checkpoint using Rust. The restoring can, however, be done by:

  • Read the checkpoint using the Python tf.train.Checkpoint.read function.
  • Save the whole model again (for Rust to read the whole model again using SavedModelBundle::load).
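The restore round-trip described above might be sketched like this (the MyModel class, variable, and paths are hypothetical; the trained checkpoint is faked in place of one written from Rust). The idea is that Python reads the checkpoint back into the model's variables and then re-exports the whole SavedModel for Rust to reload with SavedModelBundle::load.

```python
import os
import tempfile
import tensorflow as tf

class MyModel(tf.Module):
    """Stand-in for the real model."""
    def __init__(self):
        self.w = tf.Variable(0.0)

workdir = tempfile.mkdtemp()
ckpt_prefix = os.path.join(workdir, "ckpt")

# Stand-in for the checkpoint that the Rust side wrote via ckpt_write.
trained = MyModel()
trained.w.assign(42.0)
tf.train.Checkpoint(w=trained.w).write(ckpt_prefix)

# Step 1: read the checkpoint back into a fresh copy of the model.
restored = MyModel()
tf.train.Checkpoint(w=restored.w).read(ckpt_prefix).expect_partial()

# Step 2: save the whole model again, so the Rust side can reload it.
export_dir = os.path.join(workdir, "exported")
tf.saved_model.save(restored, export_dir)
```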

kotatsuyaki avatar Dec 22 '20 05:12 kotatsuyaki

It has been 2 years; is there a nicer way of doing this by now?

Trolldemorted avatar Sep 15 '22 14:09 Trolldemorted

https://github.com/tensorflow/rust/blob/ae92c63703e9c197f0f1e21e9ce5f49a3f7d6bd0/src/saved_model.rs#L258

Well, as of 2023, there's still no support for the other fields either.

AcrylicShrimp avatar Apr 23 '23 12:04 AcrylicShrimp

What a pity that this essential functionality is not yet covered properly.

bitmagier avatar Jul 31 '23 10:07 bitmagier